feat(token_tracking): add token usage tracking to both query and insert endpoints — and consequently the pipeline

This commit is contained in:
GGrassia 2025-11-26 17:00:04 +01:00
parent cd664de057
commit 3a2d3ddb9f
11 changed files with 588 additions and 104 deletions

View file

@ -397,6 +397,15 @@ def parse_args() -> argparse.Namespace:
"EMBEDDING_BATCH_NUM", DEFAULT_EMBEDDING_BATCH_NUM, int
)
# Token tracking configuration
parser.add_argument(
"--enable-token-tracking",
action="store_true",
default=get_env_value("ENABLE_TOKEN_TRACKING", False, bool),
help="Enable token usage tracking for LLM calls (default: from env or False)",
)
args.enable_token_tracking = get_env_value("ENABLE_TOKEN_TRACKING", False, bool)
ollama_server_infos.LIGHTRAG_NAME = args.simulated_model_name
ollama_server_infos.LIGHTRAG_TAG = args.simulated_model_tag

View file

@ -301,7 +301,10 @@ def create_app(args):
Path(args.working_dir).mkdir(parents=True, exist_ok=True)
def create_optimized_openai_llm_func(
config_cache: LLMConfigCache, args, llm_timeout: int
config_cache: LLMConfigCache,
args,
llm_timeout: int,
enable_token_tracking=False,
):
"""Create optimized OpenAI LLM function with pre-processed configuration"""
@ -332,13 +335,19 @@ def create_app(args):
history_messages=history_messages,
base_url=args.llm_binding_host,
api_key=args.llm_binding_api_key,
token_tracker=getattr(app.state, "token_tracker", None)
if enable_token_tracking and hasattr(app.state, "token_tracker")
else None,
**kwargs,
)
return optimized_openai_alike_model_complete
def create_optimized_azure_openai_llm_func(
config_cache: LLMConfigCache, args, llm_timeout: int
config_cache: LLMConfigCache,
args,
llm_timeout: int,
enable_token_tracking=False,
):
"""Create optimized Azure OpenAI LLM function with pre-processed configuration"""
@ -359,8 +368,8 @@ def create_app(args):
# Use pre-processed configuration to avoid repeated parsing
kwargs["timeout"] = llm_timeout
if config_cache.openai_llm_options:
kwargs.update(config_cache.openai_llm_options)
if config_cache.azure_openai_llm_options:
kwargs.update(config_cache.azure_openai_llm_options)
return await azure_openai_complete_if_cache(
args.llm_model,
@ -370,12 +379,15 @@ def create_app(args):
base_url=args.llm_binding_host,
api_key=os.getenv("AZURE_OPENAI_API_KEY", args.llm_binding_api_key),
api_version=os.getenv("AZURE_OPENAI_API_VERSION", "2024-08-01-preview"),
token_tracker=getattr(app.state, "token_tracker", None)
if enable_token_tracking and hasattr(app.state, "token_tracker")
else None,
**kwargs,
)
return optimized_azure_openai_model_complete
def create_llm_model_func(binding: str):
def create_llm_model_func(binding: str, enable_token_tracking=False):
"""
Create LLM model function based on binding type.
Uses optimized functions for OpenAI bindings and lazy import for others.
@ -384,21 +396,42 @@ def create_app(args):
if binding == "lollms":
from lightrag.llm.lollms import lollms_model_complete
return lollms_model_complete
async def lollms_model_complete_with_tracker(*args, **kwargs):
# Add token tracker if enabled
if enable_token_tracking and hasattr(app.state, "token_tracker"):
kwargs["token_tracker"] = app.state.token_tracker
return await lollms_model_complete(*args, **kwargs)
return lollms_model_complete_with_tracker
elif binding == "ollama":
from lightrag.llm.ollama import ollama_model_complete
return ollama_model_complete
async def ollama_model_complete_with_tracker(*args, **kwargs):
# Add token tracker if enabled
if enable_token_tracking and hasattr(app.state, "token_tracker"):
kwargs["token_tracker"] = app.state.token_tracker
return await ollama_model_complete(*args, **kwargs)
return ollama_model_complete_with_tracker
elif binding == "aws_bedrock":
return bedrock_model_complete # Already defined locally
async def bedrock_model_complete_with_tracker(*args, **kwargs):
# Add token tracker if enabled
if enable_token_tracking and hasattr(app.state, "token_tracker"):
kwargs["token_tracker"] = app.state.token_tracker
return await bedrock_model_complete(*args, **kwargs)
return bedrock_model_complete_with_tracker
elif binding == "azure_openai":
# Use optimized function with pre-processed configuration
return create_optimized_azure_openai_llm_func(
config_cache, args, llm_timeout
config_cache, args, llm_timeout, enable_token_tracking
)
else: # openai and compatible
# Use optimized function with pre-processed configuration
return create_optimized_openai_llm_func(config_cache, args, llm_timeout)
return create_optimized_openai_llm_func(
config_cache, args, llm_timeout, enable_token_tracking
)
except ImportError as e:
raise Exception(f"Failed to import {binding} LLM binding: {e}")
@ -422,7 +455,14 @@ def create_app(args):
return {}
def create_optimized_embedding_function(
config_cache: LLMConfigCache, binding, model, host, api_key, dimensions, args
config_cache: LLMConfigCache,
binding,
model,
host,
api_key,
dimensions,
args,
enable_token_tracking=False,
):
"""
Create optimized embedding function with pre-processed configuration for applicable bindings.
@ -435,7 +475,13 @@ def create_app(args):
from lightrag.llm.lollms import lollms_embed
return await lollms_embed(
texts, embed_model=model, host=host, api_key=api_key
texts,
embed_model=model,
host=host,
api_key=api_key,
token_tracker=getattr(app.state, "token_tracker", None)
if enable_token_tracking and hasattr(app.state, "token_tracker")
else None,
)
elif binding == "ollama":
from lightrag.llm.ollama import ollama_embed
@ -455,26 +501,54 @@ def create_app(args):
host=host,
api_key=api_key,
options=ollama_options,
token_tracker=getattr(app.state, "token_tracker", None)
if enable_token_tracking and hasattr(app.state, "token_tracker")
else None,
)
elif binding == "azure_openai":
from lightrag.llm.azure_openai import azure_openai_embed
return await azure_openai_embed(texts, model=model, api_key=api_key)
return await azure_openai_embed(
texts,
model=model,
api_key=api_key,
token_tracker=getattr(app.state, "token_tracker", None)
if enable_token_tracking and hasattr(app.state, "token_tracker")
else None,
)
elif binding == "aws_bedrock":
from lightrag.llm.bedrock import bedrock_embed
return await bedrock_embed(texts, model=model)
return await bedrock_embed(
texts,
model=model,
token_tracker=getattr(app.state, "token_tracker", None)
if enable_token_tracking and hasattr(app.state, "token_tracker")
else None,
)
elif binding == "jina":
from lightrag.llm.jina import jina_embed
return await jina_embed(
texts, dimensions=dimensions, base_url=host, api_key=api_key
texts,
dimensions=dimensions,
base_url=host,
api_key=api_key,
token_tracker=getattr(app.state, "token_tracker", None)
if enable_token_tracking and hasattr(app.state, "token_tracker")
else None,
)
else: # openai and compatible
from lightrag.llm.openai import openai_embed
return await openai_embed(
texts, model=model, base_url=host, api_key=api_key
texts,
model=model,
base_url=host,
api_key=api_key,
token_tracker=getattr(app.state, "token_tracker", None)
if enable_token_tracking and hasattr(app.state, "token_tracker")
else None,
)
except ImportError as e:
raise Exception(f"Failed to import {binding} embedding: {e}")
@ -594,7 +668,9 @@ def create_app(args):
rag = LightRAG(
working_dir=args.working_dir,
workspace=args.workspace,
llm_model_func=create_llm_model_func(args.llm_binding),
llm_model_func=create_llm_model_func(
args.llm_binding, args.enable_token_tracking
),
llm_model_name=args.llm_model,
llm_model_max_async=args.max_async,
summary_max_tokens=args.summary_max_tokens,
@ -604,7 +680,16 @@ def create_app(args):
llm_model_kwargs=create_llm_model_kwargs(
args.llm_binding, args, llm_timeout
),
embedding_func=embedding_func,
embedding_func=create_optimized_embedding_function(
config_cache,
args.embedding_binding,
args.embedding_model,
args.embedding_binding_host,
args.embedding_binding_api_key,
args.embedding_dim,
args,
args.enable_token_tracking,
),
default_llm_timeout=llm_timeout,
default_embedding_timeout=embedding_timeout,
kv_storage=args.kv_storage,
@ -629,6 +714,17 @@ def create_app(args):
logger.error(f"Failed to initialize LightRAG: {e}")
raise
# Initialize token tracking if enabled
token_tracker = None
if args.enable_token_tracking:
from lightrag.utils import TokenTracker
token_tracker = TokenTracker()
logger.info("Token tracking enabled")
# Add token tracker to the app state for use in endpoints
app.state.token_tracker = token_tracker
# Add routes
app.include_router(
create_document_routes(
@ -637,7 +733,7 @@ def create_app(args):
api_key,
)
)
app.include_router(create_query_routes(rag, api_key, args.top_k))
app.include_router(create_query_routes(rag, api_key, args.top_k, token_tracker))
app.include_router(create_graph_routes(rag, api_key))
# Add Ollama API routes

View file

@ -243,6 +243,7 @@ class InsertResponse(BaseModel):
status: Status of the operation (success, duplicated, partial_success, failure)
message: Detailed message describing the operation result
track_id: Tracking ID for monitoring processing status
token_usage: Token usage statistics for the insertion (prompt_tokens, completion_tokens, total_tokens, call_count)
"""
status: Literal["success", "duplicated", "partial_success", "failure"] = Field(
@ -250,6 +251,10 @@ class InsertResponse(BaseModel):
)
message: str = Field(description="Message describing the operation result")
track_id: str = Field(description="Tracking ID for monitoring processing status")
token_usage: Optional[Dict[str, int]] = Field(
default=None,
description="Token usage statistics for the insertion (prompt_tokens, completion_tokens, total_tokens, call_count)",
)
class Config:
json_schema_extra = {
@ -257,10 +262,17 @@ class InsertResponse(BaseModel):
"status": "success",
"message": "File 'document.pdf' uploaded successfully. Processing will continue in background.",
"track_id": "upload_20250729_170612_abc123",
"token_usage": {
"prompt_tokens": 1250,
"completion_tokens": 450,
"total_tokens": 1700,
"call_count": 3,
},
}
}
class ClearDocumentsResponse(BaseModel):
"""Response model for document clearing operation

View file

@ -6,7 +6,7 @@ import json
import logging
from typing import Any, Dict, List, Literal, Optional
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Request, Response
from lightrag.base import QueryParam
from lightrag.types import MetadataFilter
from lightrag.api.utils_api import get_combined_auth_dependency
@ -157,6 +157,10 @@ class QueryResponse(BaseModel):
default=None,
description="Reference list (Disabled when include_references=False, /query/data always includes references.)",
)
token_usage: Optional[Dict[str, int]] = Field(
default=None,
description="Token usage statistics for the query (prompt_tokens, completion_tokens, total_tokens)",
)
class QueryDataResponse(BaseModel):
@ -168,6 +172,10 @@ class QueryDataResponse(BaseModel):
metadata: Dict[str, Any] = Field(
description="Query metadata including mode, keywords, and processing information"
)
token_usage: Optional[Dict[str, int]] = Field(
default=None,
description="Token usage statistics for the query (prompt_tokens, completion_tokens, total_tokens)",
)
class StreamChunkResponse(BaseModel):
@ -183,9 +191,18 @@ class StreamChunkResponse(BaseModel):
error: Optional[str] = Field(
default=None, description="Error message if processing fails"
)
token_usage: Optional[Dict[str, int]] = Field(
default=None,
description="Token usage statistics for the entire query (only in final chunk)",
)
def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
def create_query_routes(
rag,
api_key: Optional[str] = None,
top_k: int = 60,
token_tracker: Optional[Any] = None,
):
combined_auth = get_combined_auth_dependency(api_key)
@router.post(
@ -220,7 +237,7 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
},
"examples": {
"with_references": {
"summary": "Response with references",
"summary": "Response with references and token usage",
"description": "Example response when include_references=True",
"value": {
"response": "Artificial Intelligence (AI) is a branch of computer science that aims to create intelligent machines capable of performing tasks that typically require human intelligence, such as learning, reasoning, and problem-solving.",
@ -234,13 +251,25 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
"file_path": "/documents/machine_learning.txt",
},
],
"token_usage": {
"prompt_tokens": 245,
"completion_tokens": 87,
"total_tokens": 332,
"call_count": 1,
},
},
},
"without_references": {
"summary": "Response without references",
"summary": "Response without references but with token usage",
"description": "Example response when include_references=False",
"value": {
"response": "Artificial Intelligence (AI) is a branch of computer science that aims to create intelligent machines capable of performing tasks that typically require human intelligence, such as learning, reasoning, and problem-solving."
"response": "Artificial Intelligence (AI) is a branch of computer science that aims to create intelligent machines capable of performing tasks that typically require human intelligence, such as learning, reasoning, and problem-solving.",
"token_usage": {
"prompt_tokens": 245,
"completion_tokens": 87,
"total_tokens": 332,
"call_count": 1,
},
},
},
"different_modes": {
@ -358,6 +387,10 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
- 500: Internal processing error (e.g., LLM service unavailable)
"""
try:
# Reset token tracker at start of query if available
if token_tracker:
token_tracker.reset()
param = request.to_query_params(
False
) # Ensure stream=False for non-streaming endpoint
@ -376,11 +409,22 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
if not response_content:
response_content = "No relevant context found for the query."
# Get token usage if available
token_usage = None
if token_tracker:
token_usage = token_tracker.get_usage()
# Return response with or without references based on request
if request.include_references:
return QueryResponse(response=response_content, references=references)
return QueryResponse(
response=response_content,
references=references,
token_usage=token_usage,
)
else:
return QueryResponse(response=response_content, references=None)
return QueryResponse(
response=response_content, references=None, token_usage=token_usage
)
except Exception as e:
trace_exception(e)
raise HTTPException(status_code=500, detail=str(e))
@ -577,6 +621,10 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
Use streaming mode for real-time interfaces and non-streaming for batch processing.
"""
try:
# Reset token tracker at start of query if available
if token_tracker:
token_tracker.reset()
# Use the stream parameter from the request, defaulting to True if not specified
stream_mode = request.stream if request.stream is not None else True
param = request.to_query_params(stream_mode)
@ -605,6 +653,10 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
except Exception as e:
logging.error(f"Streaming error: {str(e)}")
yield f"{json.dumps({'error': str(e)})}\n"
# Add final token usage chunk if streaming and token tracker is available
if token_tracker and llm_response.get("is_streaming"):
yield f"{json.dumps({'token_usage': token_tracker.get_usage()})}\n"
else:
# Non-streaming mode: send complete response in one message
response_content = llm_response.get("content", "")
@ -616,6 +668,10 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
if request.include_references:
complete_response["references"] = references
# Add token usage if available
if token_tracker:
complete_response["token_usage"] = token_tracker.get_usage()
yield f"{json.dumps(complete_response)}\n"
return StreamingResponse(
@ -1022,18 +1078,31 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
as structured data analysis typically requires source attribution.
"""
try:
# Reset token tracker at start of query if available
if token_tracker:
token_tracker.reset()
param = request.to_query_params(False) # No streaming for data endpoint
response = await rag.aquery_data(request.query, param=param)
# Get token usage if available
token_usage = None
if token_tracker:
token_usage = token_tracker.get_usage()
# aquery_data returns the new format with status, message, data, and metadata
if isinstance(response, dict):
return QueryDataResponse(**response)
response_dict = dict(response)
response_dict["token_usage"] = token_usage
return QueryDataResponse(**response_dict)
else:
# Handle unexpected response format
return QueryDataResponse(
status="failure",
message="Invalid response type",
message="Unexpected response format",
data={},
metadata={},
token_usage=None,
)
except Exception as e:
trace_exception(e)

View file

@ -870,7 +870,8 @@ class LightRAG:
ids: str | list[str] | None = None,
file_paths: str | list[str] | None = None,
track_id: str | None = None,
) -> str:
token_tracker: TokenTracker | None = None,
) -> tuple[str, dict]:
"""Sync Insert documents with checkpoint support
Args:
@ -884,10 +885,10 @@ class LightRAG:
track_id: tracking ID for monitoring processing status, if not provided, will be generated
Returns:
str: tracking ID for monitoring processing status
tuple[str, dict]: (tracking ID for monitoring processing status, token usage statistics)
"""
loop = always_get_an_event_loop()
return loop.run_until_complete(
result = loop.run_until_complete(
self.ainsert(
input,
split_by_character,
@ -895,8 +896,10 @@ class LightRAG:
ids,
file_paths,
track_id,
token_tracker,
)
)
return result
async def ainsert(
self,
@ -906,7 +909,8 @@ class LightRAG:
ids: str | list[str] | None = None,
file_paths: str | list[str] | None = None,
track_id: str | None = None,
) -> str:
token_tracker: TokenTracker | None = None,
) -> tuple[str, dict]:
"""Async Insert documents with checkpoint support
Args:
@ -930,9 +934,15 @@ class LightRAG:
await self.apipeline_process_enqueue_documents(
split_by_character,
split_by_character_only,
token_tracker,
)
return track_id
return track_id, token_tracker.get_usage() if token_tracker else {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0,
"call_count": 0,
}
# TODO: deprecated, use insert instead
def insert_custom_chunks(
@ -1107,7 +1117,9 @@ class LightRAG:
"file_path"
], # Store file path in document status
"track_id": track_id, # Store track_id in document status
"metadata": metadata[i] if isinstance(metadata, list) and i < len(metadata) else metadata, # added provided custom metadata per document
"metadata": metadata[i]
if isinstance(metadata, list) and i < len(metadata)
else metadata, # added provided custom metadata per document
}
for i, (id_, content_data) in enumerate(contents.items())
}
@ -1363,6 +1375,7 @@ class LightRAG:
self,
split_by_character: str | None = None,
split_by_character_only: bool = False,
token_tracker: TokenTracker | None = None,
) -> None:
"""
Process pending documents by splitting them into chunks, processing
@ -1484,9 +1497,12 @@ class LightRAG:
pipeline_status: dict,
pipeline_status_lock: asyncio.Lock,
semaphore: asyncio.Semaphore,
token_tracker: TokenTracker | None = None,
) -> None:
"""Process single document"""
doc_metadata = getattr(status_doc, "metadata", None)
doc_metadata = getattr(status_doc, "metadata", None)
if doc_metadata is None:
doc_metadata = {}
file_extraction_stage_ok = False
async with semaphore:
nonlocal processed_count
@ -1498,7 +1514,6 @@ class LightRAG:
# Get file path from status document
file_path = getattr(
status_doc, "file_path", "unknown_source"
)
async with pipeline_status_lock:
@ -1561,7 +1576,9 @@ class LightRAG:
# Process document in two stages
# Stage 1: Process text chunks and docs (parallel execution)
doc_metadata["processing_start_time"] = processing_start_time
doc_metadata["processing_start_time"] = (
processing_start_time
)
doc_status_task = asyncio.create_task(
self.doc_status.upsert(
@ -1610,6 +1627,7 @@ class LightRAG:
doc_metadata,
pipeline_status,
pipeline_status_lock,
token_tracker,
)
)
await entity_relation_task
@ -1643,7 +1661,9 @@ class LightRAG:
processing_end_time = int(time.time())
# Update document status to failed
doc_metadata["processing_start_time"] = processing_start_time
doc_metadata["processing_start_time"] = (
processing_start_time
)
doc_metadata["processing_end_time"] = processing_end_time
await self.doc_status.upsert(
{
@ -1691,7 +1711,9 @@ class LightRAG:
doc_metadata["processing_start_time"] = (
processing_start_time
)
doc_metadata["processing_end_time"] = processing_end_time
doc_metadata["processing_end_time"] = (
processing_end_time
)
await self.doc_status.upsert(
{
@ -1748,7 +1770,9 @@ class LightRAG:
doc_metadata["processing_start_time"] = (
processing_start_time
)
doc_metadata["processing_end_time"] = processing_end_time
doc_metadata["processing_end_time"] = (
processing_end_time
)
await self.doc_status.upsert(
{
@ -1778,6 +1802,7 @@ class LightRAG:
pipeline_status,
pipeline_status_lock,
semaphore,
token_tracker,
)
)
@ -1827,6 +1852,7 @@ class LightRAG:
metadata: dict | None,
pipeline_status=None,
pipeline_status_lock=None,
token_tracker: TokenTracker | None = None,
) -> list:
try:
chunk_results = await extract_entities(
@ -1837,6 +1863,7 @@ class LightRAG:
pipeline_status_lock=pipeline_status_lock,
llm_response_cache=self.llm_response_cache,
text_chunks_storage=self.text_chunks,
token_tracker=token_tracker,
)
return chunk_results
except Exception as e:
@ -2061,14 +2088,18 @@ class LightRAG:
self,
query: str,
param: QueryParam = QueryParam(),
token_tracker: TokenTracker | None = None,
system_prompt: str | None = None,
) -> str | Iterator[str]:
) -> Any:
"""
Perform a sync query.
User query interface (backward compatibility wrapper).
Delegates to aquery() for asynchronous execution and returns the result.
Args:
query (str): The query to be executed.
param (QueryParam): Configuration parameters for query execution.
token_tracker (TokenTracker | None): Optional token tracker for monitoring usage.
prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].
Returns:
@ -2076,14 +2107,17 @@ class LightRAG:
"""
loop = always_get_an_event_loop()
return loop.run_until_complete(self.aquery(query, param, system_prompt)) # type: ignore
return loop.run_until_complete(
self.aquery(query, param, token_tracker, system_prompt)
) # type: ignore
async def aquery(
self,
query: str,
param: QueryParam = QueryParam(),
token_tracker: TokenTracker | None = None,
system_prompt: str | None = None,
) -> str | AsyncIterator[str]:
) -> Any:
"""
Perform a async query (backward compatibility wrapper).
@ -2102,7 +2136,7 @@ class LightRAG:
- Streaming: Returns AsyncIterator[str]
"""
# Call the new aquery_llm function to get complete results
result = await self.aquery_llm(query, param, system_prompt)
result = await self.aquery_llm(query, param, system_prompt, token_tracker)
# Extract and return only the LLM response for backward compatibility
llm_response = result.get("llm_response", {})
@ -2137,6 +2171,7 @@ class LightRAG:
self,
query: str,
param: QueryParam = QueryParam(),
token_tracker: TokenTracker | None = None,
) -> dict[str, Any]:
"""
Asynchronous data retrieval API: returns structured retrieval results without LLM generation.
@ -2330,6 +2365,7 @@ class LightRAG:
query: str,
param: QueryParam = QueryParam(),
system_prompt: str | None = None,
token_tracker: TokenTracker | None = None,
) -> dict[str, Any]:
"""
Asynchronous complete query API: returns structured retrieval results with LLM generation.
@ -2364,6 +2400,7 @@ class LightRAG:
hashing_kv=self.llm_response_cache,
system_prompt=system_prompt,
chunks_vdb=self.chunks_vdb,
token_tracker=token_tracker,
)
elif param.mode == "naive":
query_result = await naive_query(
@ -2373,6 +2410,7 @@ class LightRAG:
global_config,
hashing_kv=self.llm_response_cache,
system_prompt=system_prompt,
token_tracker=token_tracker,
)
elif param.mode == "bypass":
# Bypass mode: directly use LLM without knowledge retrieval

View file

@ -46,6 +46,7 @@ async def azure_openai_complete_if_cache(
base_url: str | None = None,
api_key: str | None = None,
api_version: str | None = None,
token_tracker: Any | None = None,
**kwargs,
):
if enable_cot:
@ -94,28 +95,73 @@ async def azure_openai_complete_if_cache(
)
if hasattr(response, "__aiter__"):
final_chunk_usage = None
accumulated_response = ""
async def inner():
async for chunk in response:
if len(chunk.choices) == 0:
continue
content = chunk.choices[0].delta.content
if content is None:
continue
if r"\u" in content:
content = safe_unicode_decode(content.encode("utf-8"))
yield content
nonlocal final_chunk_usage, accumulated_response
try:
async for chunk in response:
if len(chunk.choices) == 0:
continue
content = chunk.choices[0].delta.content
if content is None:
continue
accumulated_response += content
if r"\u" in content:
content = safe_unicode_decode(content.encode("utf-8"))
yield content
# Check for usage in the last chunk
if hasattr(chunk, "usage") and chunk.usage is not None:
final_chunk_usage = chunk.usage
except Exception as e:
logger.error(f"Error in Azure OpenAI stream response: {str(e)}")
raise
finally:
# After streaming is complete, track token usage
if token_tracker and final_chunk_usage:
# Use actual usage from the API
token_counts = {
"prompt_tokens": getattr(final_chunk_usage, "prompt_tokens", 0),
"completion_tokens": getattr(
final_chunk_usage, "completion_tokens", 0
),
"total_tokens": getattr(final_chunk_usage, "total_tokens", 0),
}
token_tracker.add_usage(token_counts)
logger.debug(f"Azure OpenAI streaming token usage: {token_counts}")
elif token_tracker:
logger.debug(
"No usage information available in Azure OpenAI streaming response"
)
return inner()
else:
content = response.choices[0].message.content
if r"\u" in content:
content = safe_unicode_decode(content.encode("utf-8"))
# Track token usage for non-streaming response
if token_tracker and hasattr(response, "usage"):
token_counts = {
"prompt_tokens": getattr(response.usage, "prompt_tokens", 0),
"completion_tokens": getattr(response.usage, "completion_tokens", 0),
"total_tokens": getattr(response.usage, "total_tokens", 0),
}
token_tracker.add_usage(token_counts)
logger.debug(f"Azure OpenAI non-streaming token usage: {token_counts}")
return content
async def azure_openai_complete(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
prompt,
system_prompt=None,
history_messages=[],
keyword_extraction=False,
token_tracker=None,
**kwargs,
) -> str:
kwargs.pop("keyword_extraction", None)
result = await azure_openai_complete_if_cache(
@ -123,6 +169,7 @@ async def azure_openai_complete(
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
token_tracker=token_tracker,
**kwargs,
)
return result
@ -142,6 +189,7 @@ async def azure_openai_embed(
base_url: str | None = None,
api_key: str | None = None,
api_version: str | None = None,
token_tracker: Any | None = None,
) -> np.ndarray:
deployment = (
os.getenv("AZURE_EMBEDDING_DEPLOYMENT")
@ -174,4 +222,14 @@ async def azure_openai_embed(
response = await openai_async_client.embeddings.create(
model=model, input=texts, encoding_format="float"
)
# Track token usage for embeddings if token tracker is provided
if token_tracker and hasattr(response, "usage"):
token_counts = {
"prompt_tokens": getattr(response.usage, "prompt_tokens", 0),
"total_tokens": getattr(response.usage, "total_tokens", 0),
}
token_tracker.add_usage(token_counts)
logger.debug(f"Azure OpenAI embedding token usage: {token_counts}")
return np.array([dp.embedding for dp in response.data])

View file

@ -48,6 +48,7 @@ async def bedrock_complete_if_cache(
aws_access_key_id=None,
aws_secret_access_key=None,
aws_session_token=None,
token_tracker=None,
**kwargs,
) -> Union[str, AsyncIterator[str]]:
if enable_cot:
@ -155,6 +156,18 @@ async def bedrock_complete_if_cache(
yield text
# Handle other event types that might indicate stream end
elif "messageStop" in event:
# Track token usage for streaming if token tracker is provided
if token_tracker and "usage" in event:
usage = event["usage"]
token_counts = {
"prompt_tokens": usage.get("inputTokens", 0),
"completion_tokens": usage.get("outputTokens", 0),
"total_tokens": usage.get("totalTokens", 0),
}
token_tracker.add_usage(token_counts)
logging.debug(
f"Bedrock streaming token usage: {token_counts}"
)
break
except Exception as e:
@ -228,6 +241,17 @@ async def bedrock_complete_if_cache(
if not content or content.strip() == "":
raise BedrockError("Received empty content from Bedrock API")
# Track token usage for non-streaming if token tracker is provided
if token_tracker and "usage" in response:
usage = response["usage"]
token_counts = {
"prompt_tokens": usage.get("inputTokens", 0),
"completion_tokens": usage.get("outputTokens", 0),
"total_tokens": usage.get("totalTokens", 0),
}
token_tracker.add_usage(token_counts)
logging.debug(f"Bedrock non-streaming token usage: {token_counts}")
return content
except Exception as e:
@ -239,7 +263,12 @@ async def bedrock_complete_if_cache(
# Generic Bedrock completion function
async def bedrock_complete(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
prompt,
system_prompt=None,
history_messages=[],
keyword_extraction=False,
token_tracker=None,
**kwargs,
) -> Union[str, AsyncIterator[str]]:
kwargs.pop("keyword_extraction", None)
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
@ -248,6 +277,7 @@ async def bedrock_complete(
prompt,
system_prompt=system_prompt,
history_messages=history_messages,
token_tracker=token_tracker,
**kwargs,
)
return result
@ -265,6 +295,7 @@ async def bedrock_embed(
aws_access_key_id=None,
aws_secret_access_key=None,
aws_session_token=None,
token_tracker=None,
) -> np.ndarray:
# Respect existing env; only set if a non-empty value is available
access_key = os.environ.get("AWS_ACCESS_KEY_ID") or aws_access_key_id

View file

@ -108,6 +108,7 @@ async def lollms_model_complete(
history_messages=[],
enable_cot: bool = False,
keyword_extraction=False,
token_tracker=None,
**kwargs,
) -> Union[str, AsyncIterator[str]]:
"""Complete function for lollms model generation."""
@ -135,7 +136,11 @@ async def lollms_model_complete(
async def lollms_embed(
texts: List[str], embed_model=None, base_url="http://localhost:9600", **kwargs
texts: List[str],
embed_model=None,
base_url="http://localhost:9600",
token_tracker=None,
**kwargs,
) -> np.ndarray:
"""
Generate embeddings for a list of texts using lollms server.
@ -144,6 +149,7 @@ async def lollms_embed(
texts: List of strings to embed
embed_model: Model name (not used directly as lollms uses configured vectorizer)
base_url: URL of the lollms server
token_tracker: Optional token usage tracker for monitoring API usage
**kwargs: Additional arguments passed to the request
Returns:

View file

@ -39,6 +39,7 @@ async def _ollama_model_if_cache(
system_prompt=None,
history_messages=[],
enable_cot: bool = False,
token_tracker=None,
**kwargs,
) -> Union[str, AsyncIterator[str]]:
if enable_cot:
@ -74,13 +75,47 @@ async def _ollama_model_if_cache(
"""cannot cache stream response and process reasoning"""
async def inner():
accumulated_response = ""
try:
async for chunk in response:
yield chunk["message"]["content"]
chunk_content = chunk["message"]["content"]
accumulated_response += chunk_content
yield chunk_content
except Exception as e:
logger.error(f"Error in stream response: {str(e)}")
raise
finally:
# Track token usage for streaming if token tracker is provided
if token_tracker:
# Estimate prompt tokens: roughly 4 characters per token for English text
prompt_text = ""
if system_prompt:
prompt_text += system_prompt + " "
prompt_text += (
" ".join(
[msg.get("content", "") for msg in history_messages]
)
+ " "
)
prompt_text += prompt
prompt_tokens = len(prompt_text) // 4 + (
1 if len(prompt_text) % 4 else 0
)
# Estimate completion tokens from accumulated response
completion_tokens = len(accumulated_response) // 4 + (
1 if len(accumulated_response) % 4 else 0
)
total_tokens = prompt_tokens + completion_tokens
token_tracker.add_usage(
{
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": total_tokens,
}
)
try:
await ollama_client._client.aclose()
logger.debug("Successfully closed Ollama client for streaming")
@ -91,6 +126,35 @@ async def _ollama_model_if_cache(
else:
model_response = response["message"]["content"]
# Track token usage if token tracker is provided
# Note: Ollama doesn't provide token usage in chat responses, so we estimate
if token_tracker:
# Estimate prompt tokens: roughly 4 characters per token for English text
prompt_text = ""
if system_prompt:
prompt_text += system_prompt + " "
prompt_text += (
" ".join([msg.get("content", "") for msg in history_messages]) + " "
)
prompt_text += prompt
prompt_tokens = len(prompt_text) // 4 + (
1 if len(prompt_text) % 4 else 0
)
# Estimate completion tokens from response
completion_tokens = len(model_response) // 4 + (
1 if len(model_response) % 4 else 0
)
total_tokens = prompt_tokens + completion_tokens
token_tracker.add_usage(
{
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": total_tokens,
}
)
"""
If the model also wraps its thoughts in a specific tag,
this information is not needed for the final
@ -126,6 +190,7 @@ async def ollama_model_complete(
history_messages=[],
enable_cot: bool = False,
keyword_extraction=False,
token_tracker=None,
**kwargs,
) -> Union[str, AsyncIterator[str]]:
keyword_extraction = kwargs.pop("keyword_extraction", None)
@ -138,11 +203,14 @@ async def ollama_model_complete(
system_prompt=system_prompt,
history_messages=history_messages,
enable_cot=enable_cot,
token_tracker=token_tracker,
**kwargs,
)
async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
async def ollama_embed(
texts: list[str], embed_model, token_tracker=None, **kwargs
) -> np.ndarray:
api_key = kwargs.pop("api_key", None)
headers = {
"Content-Type": "application/json",
@ -160,6 +228,21 @@ async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
data = await ollama_client.embed(
model=embed_model, input=texts, options=options
)
# Track token usage if token tracker is provided
# Note: Ollama doesn't provide token usage in embedding responses, so we estimate
if token_tracker:
# Estimate tokens: roughly 4 characters per token for English text
total_chars = sum(len(text) for text in texts)
estimated_tokens = total_chars // 4 + (1 if total_chars % 4 else 0)
token_tracker.add_usage(
{
"prompt_tokens": estimated_tokens,
"completion_tokens": 0,
"total_tokens": estimated_tokens,
}
)
return np.array(data["embeddings"])
except Exception as e:
logger.error(f"Error in ollama_embed: {str(e)}")

View file

@ -12,6 +12,7 @@ from .utils import (
logger,
compute_mdhash_id,
Tokenizer,
TokenTracker,
is_float_regex,
sanitize_and_normalize_extracted_text,
pack_user_ass_to_openai_messages,
@ -126,6 +127,7 @@ async def _handle_entity_relation_summary(
seperator: str,
global_config: dict,
llm_response_cache: BaseKVStorage | None = None,
token_tracker: TokenTracker | None = None,
) -> tuple[str, bool]:
"""Handle entity relation description summary using map-reduce approach.
@ -188,6 +190,7 @@ async def _handle_entity_relation_summary(
current_list,
global_config,
llm_response_cache,
token_tracker,
)
return final_summary, True # LLM was used for final summarization
@ -243,6 +246,7 @@ async def _handle_entity_relation_summary(
chunk,
global_config,
llm_response_cache,
token_tracker,
)
new_summaries.append(summary)
llm_was_used = True # Mark that LLM was used in reduce phase
@ -257,6 +261,7 @@ async def _summarize_descriptions(
description_list: list[str],
global_config: dict,
llm_response_cache: BaseKVStorage | None = None,
token_tracker: TokenTracker | None = None,
) -> str:
"""Helper function to summarize a list of descriptions using LLM.
@ -312,9 +317,10 @@ async def _summarize_descriptions(
# Use LLM function with cache (higher priority for summary generation)
summary, _ = await use_llm_func_with_cache(
use_prompt,
use_llm_func,
llm_response_cache=llm_response_cache,
use_llm_func=use_llm_func,
hashing_kv=llm_response_cache,
cache_type="summary",
token_tracker=token_tracker,
)
return summary
@ -405,7 +411,7 @@ async def _handle_single_relationship_extraction(
): # treat "relationship" and "relation" interchangeable
if len(record_attributes) > 1 and "relation" in record_attributes[0]:
logger.warning(
f"{chunk_key}: LLM output format error; found {len(record_attributes)}/5 fields on REALTION `{record_attributes[1]}`~`{record_attributes[2] if len(record_attributes) >2 else 'N/A'}`"
f"{chunk_key}: LLM output format error; found {len(record_attributes)}/5 fields on REALTION `{record_attributes[1]}`~`{record_attributes[2] if len(record_attributes) > 2 else 'N/A'}`"
)
logger.debug(record_attributes)
return None
@ -463,7 +469,6 @@ async def _handle_single_relationship_extraction(
file_path=file_path,
timestamp=timestamp,
metadata=metadata,
)
except ValueError as e:
@ -2037,6 +2042,7 @@ async def extract_entities(
pipeline_status_lock=None,
llm_response_cache: BaseKVStorage | None = None,
text_chunks_storage: BaseKVStorage | None = None,
token_tracker: TokenTracker | None = None,
) -> list:
use_llm_func: callable = global_config["llm_model_func"]
entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
@ -2150,12 +2156,13 @@ async def extract_entities(
final_result, timestamp = await use_llm_func_with_cache(
entity_extraction_user_prompt,
use_llm_func,
use_llm_func=use_llm_func,
system_prompt=entity_extraction_system_prompt,
llm_response_cache=llm_response_cache,
hashing_kv=llm_response_cache,
cache_type="extract",
chunk_id=chunk_key,
cache_keys_collector=cache_keys_collector,
token_tracker=token_tracker,
)
history = pack_user_ass_to_openai_messages(
@ -2177,16 +2184,16 @@ async def extract_entities(
if entity_extract_max_gleaning > 0:
glean_result, timestamp = await use_llm_func_with_cache(
entity_continue_extraction_user_prompt,
use_llm_func,
use_llm_func=use_llm_func,
system_prompt=entity_extraction_system_prompt,
llm_response_cache=llm_response_cache,
history_messages=history,
cache_type="extract",
chunk_id=chunk_key,
cache_keys_collector=cache_keys_collector,
token_tracker=token_tracker,
)
# Process gleaning result separately with file path and metadata
glean_nodes, glean_edges = await _process_extraction_result(
glean_result,
@ -2300,7 +2307,7 @@ async def extract_entities(
await asyncio.wait(pending)
# Add progress prefix to the exception message
progress_prefix = f"C[{processed_chunks+1}/{total_chunks}]"
progress_prefix = f"C[{processed_chunks + 1}/{total_chunks}]"
# Re-raise the original exception with a prefix
prefixed_exception = create_prefixed_exception(first_exception, progress_prefix)
@ -2324,6 +2331,7 @@ async def kg_query(
system_prompt: str | None = None,
chunks_vdb: BaseVectorStorage = None,
return_raw_data: Literal[True] = False,
token_tracker: TokenTracker | None = None,
) -> dict[str, Any]: ...
@ -2341,6 +2349,7 @@ async def kg_query(
chunks_vdb: BaseVectorStorage = None,
metadata_filters: list | None = None,
return_raw_data: Literal[False] = False,
token_tracker: TokenTracker | None = None,
) -> str | AsyncIterator[str]: ...
@ -2355,6 +2364,7 @@ async def kg_query(
hashing_kv: BaseKVStorage | None = None,
system_prompt: str | None = None,
chunks_vdb: BaseVectorStorage = None,
token_tracker: TokenTracker | None = None,
) -> QueryResult:
"""
Execute knowledge graph query and return unified QueryResult object.
@ -2422,7 +2432,7 @@ async def kg_query(
return QueryResult(content=cached_response)
hl_keywords, ll_keywords = await get_keywords_from_query(
query, query_param, global_config, hashing_kv
query, query_param, global_config, hashing_kv, token_tracker
)
logger.debug(f"High-level keywords: {hl_keywords}")
@ -2526,6 +2536,7 @@ async def kg_query(
history_messages=query_param.conversation_history,
enable_cot=True,
stream=query_param.stream,
token_tracker=token_tracker,
)
if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"):
@ -2583,6 +2594,7 @@ async def get_keywords_from_query(
query_param: QueryParam,
global_config: dict[str, str],
hashing_kv: BaseKVStorage | None = None,
token_tracker: TokenTracker | None = None,
) -> tuple[list[str], list[str]]:
"""
Retrieves high-level and low-level keywords for RAG operations.
@ -2605,7 +2617,7 @@ async def get_keywords_from_query(
# Extract keywords using extract_keywords_only function which already supports conversation history
hl_keywords, ll_keywords = await extract_keywords_only(
query, query_param, global_config, hashing_kv
query, query_param, global_config, hashing_kv, token_tracker
)
return hl_keywords, ll_keywords
@ -2615,6 +2627,7 @@ async def extract_keywords_only(
param: QueryParam,
global_config: dict[str, str],
hashing_kv: BaseKVStorage | None = None,
token_tracker: TokenTracker | None = None,
) -> tuple[list[str], list[str]]:
"""
Extract high-level and low-level keywords from the given 'text' using the LLM.
@ -2668,7 +2681,9 @@ async def extract_keywords_only(
# Apply higher priority (5) to query relation LLM function
use_model_func = partial(use_model_func, _priority=5)
result = await use_model_func(kw_prompt, keyword_extraction=True)
result = await use_model_func(
kw_prompt, keyword_extraction=True, token_tracker=token_tracker
)
# 5. Parse out JSON from the LLM response
result = remove_think_tags(result)
@ -2746,7 +2761,10 @@ async def _get_vector_context(
cosine_threshold = chunks_vdb.cosine_better_than_threshold
results = await chunks_vdb.query(
query, top_k=search_top_k, query_embedding=query_embedding, metadata_filter=query_param.metadata_filter
query,
top_k=search_top_k,
query_embedding=query_embedding,
metadata_filter=query_param.metadata_filter,
)
if not results:
logger.info(
@ -2763,7 +2781,7 @@ async def _get_vector_context(
"file_path": result.get("file_path", "unknown_source"),
"source_type": "vector", # Mark the source type
"chunk_id": result.get("id"), # Add chunk_id for deduplication
"metadata": result.get("metadata")
"metadata": result.get("metadata"),
}
valid_chunks.append(chunk_with_metadata)
@ -3529,8 +3547,9 @@ async def _get_node_data(
f"Query nodes: {query} (top_k:{query_param.top_k}, cosine:{entities_vdb.cosine_better_than_threshold})"
)
results = await entities_vdb.query(query, top_k=query_param.top_k, metadata_filter=query_param.metadata_filter)
results = await entities_vdb.query(
query, top_k=query_param.top_k, metadata_filter=query_param.metadata_filter
)
if not len(results):
return [], []
@ -3538,7 +3557,6 @@ async def _get_node_data(
# Extract all entity IDs from your results list
node_ids = [r["entity_name"] for r in results]
# Extract all entity IDs from your results list
node_ids = [r["entity_name"] for r in results]
@ -3810,7 +3828,9 @@ async def _get_edge_data(
f"Query edges: {keywords} (top_k:{query_param.top_k}, cosine:{relationships_vdb.cosine_better_than_threshold})"
)
results = await relationships_vdb.query(keywords, top_k=query_param.top_k, metadata_filter=query_param.metadata_filter)
results = await relationships_vdb.query(
keywords, top_k=query_param.top_k, metadata_filter=query_param.metadata_filter
)
if not len(results):
return [], []
@ -4104,6 +4124,7 @@ async def naive_query(
hashing_kv: BaseKVStorage | None = None,
system_prompt: str | None = None,
return_raw_data: Literal[True] = True,
token_tracker: TokenTracker | None = None,
) -> dict[str, Any]: ...
@ -4116,6 +4137,7 @@ async def naive_query(
hashing_kv: BaseKVStorage | None = None,
system_prompt: str | None = None,
return_raw_data: Literal[False] = False,
token_tracker: TokenTracker | None = None,
) -> str | AsyncIterator[str]: ...
@ -4126,6 +4148,7 @@ async def naive_query(
global_config: dict[str, str],
hashing_kv: BaseKVStorage | None = None,
system_prompt: str | None = None,
token_tracker: TokenTracker | None = None,
) -> QueryResult:
"""
Execute naive query and return unified QueryResult object.
@ -4321,6 +4344,7 @@ async def naive_query(
history_messages=query_param.conversation_history,
enable_cot=True,
stream=query_param.stream,
token_tracker=token_tracker,
)
if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"):

View file

@ -4,6 +4,7 @@ import weakref
import asyncio
import html
import csv
import contextvars
import json
import logging
import logging.handlers
@ -507,6 +508,7 @@ def priority_limit_async_func_call(
task_id,
args,
kwargs,
ctx,
) = await asyncio.wait_for(queue.get(), timeout=1.0)
except asyncio.TimeoutError:
continue
@ -536,11 +538,15 @@ def priority_limit_async_func_call(
try:
# Execute function with timeout protection
if max_execution_timeout is not None:
# Run the function in the captured context
task = ctx.run(lambda: asyncio.create_task(func(*args, **kwargs)))
result = await asyncio.wait_for(
func(*args, **kwargs), timeout=max_execution_timeout
task, timeout=max_execution_timeout
)
else:
result = await func(*args, **kwargs)
# Run the function in the captured context
task = ctx.run(lambda: asyncio.create_task(func(*args, **kwargs)))
result = await task
# Set result if future is still valid
if not task_state.future.done():
@ -791,6 +797,9 @@ def priority_limit_async_func_call(
future=future, start_time=asyncio.get_event_loop().time()
)
# Capture current context
ctx = contextvars.copy_context()
try:
# Register task state
async with task_states_lock:
@ -809,13 +818,13 @@ def priority_limit_async_func_call(
if _queue_timeout is not None:
await asyncio.wait_for(
queue.put(
(_priority, current_count, task_id, args, kwargs)
(_priority, current_count, task_id, args, kwargs, ctx)
),
timeout=_queue_timeout,
)
else:
await queue.put(
(_priority, current_count, task_id, args, kwargs)
(_priority, current_count, task_id, args, kwargs, ctx)
)
except asyncio.TimeoutError:
raise QueueFullError(
@ -1472,8 +1481,7 @@ async def aexport_data(
else:
raise ValueError(
f"Unsupported file format: {file_format}. "
f"Choose from: csv, excel, md, txt"
f"Unsupported file format: {file_format}. Choose from: csv, excel, md, txt"
)
if file_format is not None:
print(f"Data exported to: {output_path} with format: {file_format}")
@ -1601,6 +1609,8 @@ async def use_llm_func_with_cache(
cache_type: str = "extract",
chunk_id: str | None = None,
cache_keys_collector: list = None,
hashing_kv: "BaseKVStorage | None" = None,
token_tracker=None,
) -> tuple[str, int]:
"""Call LLM function with cache support and text sanitization
@ -1685,7 +1695,10 @@ async def use_llm_func_with_cache(
kwargs["max_tokens"] = max_tokens
res: str = await use_llm_func(
safe_user_prompt, system_prompt=safe_system_prompt, **kwargs
safe_user_prompt,
system_prompt=safe_system_prompt,
token_tracker=token_tracker,
**kwargs,
)
res = remove_think_tags(res)
@ -1720,7 +1733,10 @@ async def use_llm_func_with_cache(
try:
res = await use_llm_func(
safe_user_prompt, system_prompt=safe_system_prompt, **kwargs
safe_user_prompt,
system_prompt=safe_system_prompt,
token_tracker=token_tracker,
**kwargs,
)
except Exception as e:
# Add [LLM func] prefix to error message
@ -2216,52 +2232,74 @@ async def pick_by_vector_similarity(
return all_chunk_ids[:num_of_chunks]
from contextvars import ContextVar


class TokenTracker:
    """Track token usage for LLM calls using ContextVars for concurrency support.

    All counters live in a ``ContextVar`` rather than on the instance, so
    concurrent asyncio tasks (each running in its own copied context) get
    independent tallies even when they share one ``TokenTracker`` object.
    """

    # Per-context usage dict; ``None`` means "not yet initialized in this context".
    _usage_var: ContextVar[dict] = ContextVar("token_usage", default=None)

    def __init__(self):
        # No instance state needed as we use ContextVar
        pass

    def __enter__(self):
        # Start each ``with`` scope from a clean slate in the current context.
        self.reset()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Optional: Log usage on exit if needed
        pass

    def reset(self):
        """Initialize/Reset token usage for the current context."""
        self._usage_var.set(
            {
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "total_tokens": 0,
                "call_count": 0,
            }
        )

    def _get_current_usage(self) -> dict:
        """Get the usage dict for the current context, initializing if necessary."""
        usage = self._usage_var.get()
        if usage is None:
            usage = {
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "total_tokens": 0,
                "call_count": 0,
            }
            self._usage_var.set(usage)
        return usage

    def add_usage(self, token_counts: dict):
        """Add token usage from one LLM call.

        Args:
            token_counts: A dictionary containing prompt_tokens,
                completion_tokens, total_tokens
        """
        usage = self._get_current_usage()
        usage["prompt_tokens"] += token_counts.get("prompt_tokens", 0)
        usage["completion_tokens"] += token_counts.get("completion_tokens", 0)
        # If total_tokens is provided, use it directly; otherwise calculate the sum
        if "total_tokens" in token_counts:
            usage["total_tokens"] += token_counts["total_tokens"]
        else:
            usage["total_tokens"] += token_counts.get(
                "prompt_tokens", 0
            ) + token_counts.get("completion_tokens", 0)
        usage["call_count"] += 1

    def get_usage(self):
        """Get current usage statistics as a copy (callers cannot mutate state)."""
        return self._get_current_usage().copy()

    def __str__(self):
        # NOTE(review): message format reconstructed — the original __str__ body
        # is not fully visible in this view; confirm wording against the file.
        usage = self.get_usage()
        return (
            f"LLM call count: {usage['call_count']}, "
            f"Prompt tokens: {usage['prompt_tokens']}, "
            f"Completion tokens: {usage['completion_tokens']}, "
            f"Total tokens: {usage['total_tokens']}"
        )
def estimate_embedding_tokens(texts: list[str], tokenizer: Tokenizer) -> int:
    """Approximate the total token count for a batch of embedding inputs.

    Most embedding APIs do not report token usage in their responses, so this
    encodes every non-empty string with the given tokenizer and sums the
    resulting lengths — a reasonable approximation for usage tracking.

    Args:
        texts: Text strings destined for the embedding endpoint.
        tokenizer: Tokenizer instance used to encode each string.

    Returns:
        Sum of encoded token counts across all non-empty texts.
    """
    return sum(len(tokenizer.encode(text)) for text in texts if text)
async def apply_rerank_if_enabled(
query: str,
retrieved_docs: list[dict],