From 3a2d3ddb9fce101b04f82a6b299e29cdebc4585c Mon Sep 17 00:00:00 2001 From: GGrassia Date: Wed, 26 Nov 2025 17:00:04 +0100 Subject: [PATCH] feat (token_tracking): added tracking token to both query and insert endpoints --and consequently pipeline --- lightrag/api/config.py | 9 ++ lightrag/api/lightrag_server.py | 134 ++++++++++++++++++++---- lightrag/api/routers/document_routes.py | 12 +++ lightrag/api/routers/query_routes.py | 87 +++++++++++++-- lightrag/lightrag.py | 72 ++++++++++--- lightrag/llm/azure_openai.py | 78 ++++++++++++-- lightrag/llm/bedrock.py | 33 +++++- lightrag/llm/lollms.py | 8 +- lightrag/llm/ollama.py | 87 ++++++++++++++- lightrag/operate.py | 60 +++++++---- lightrag/utils.py | 112 +++++++++++++++----- 11 files changed, 588 insertions(+), 104 deletions(-) diff --git a/lightrag/api/config.py b/lightrag/api/config.py index de569f47..e17c57e7 100644 --- a/lightrag/api/config.py +++ b/lightrag/api/config.py @@ -397,6 +397,15 @@ def parse_args() -> argparse.Namespace: "EMBEDDING_BATCH_NUM", DEFAULT_EMBEDDING_BATCH_NUM, int ) + # Token tracking configuration + parser.add_argument( + "--enable-token-tracking", + action="store_true", + default=get_env_value("ENABLE_TOKEN_TRACKING", False, bool), + help="Enable token usage tracking for LLM calls (default: from env or False)", + ) + args.enable_token_tracking = get_env_value("ENABLE_TOKEN_TRACKING", False, bool) + ollama_server_infos.LIGHTRAG_NAME = args.simulated_model_name ollama_server_infos.LIGHTRAG_TAG = args.simulated_model_tag diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index fb0f7985..99316ba3 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -301,7 +301,10 @@ def create_app(args): Path(args.working_dir).mkdir(parents=True, exist_ok=True) def create_optimized_openai_llm_func( - config_cache: LLMConfigCache, args, llm_timeout: int + config_cache: LLMConfigCache, + args, + llm_timeout: int, + enable_token_tracking=False, ): 
"""Create optimized OpenAI LLM function with pre-processed configuration""" @@ -332,13 +335,19 @@ def create_app(args): history_messages=history_messages, base_url=args.llm_binding_host, api_key=args.llm_binding_api_key, + token_tracker=getattr(app.state, "token_tracker", None) + if enable_token_tracking and hasattr(app.state, "token_tracker") + else None, **kwargs, ) return optimized_openai_alike_model_complete def create_optimized_azure_openai_llm_func( - config_cache: LLMConfigCache, args, llm_timeout: int + config_cache: LLMConfigCache, + args, + llm_timeout: int, + enable_token_tracking=False, ): """Create optimized Azure OpenAI LLM function with pre-processed configuration""" @@ -359,8 +368,8 @@ def create_app(args): # Use pre-processed configuration to avoid repeated parsing kwargs["timeout"] = llm_timeout - if config_cache.openai_llm_options: - kwargs.update(config_cache.openai_llm_options) + if config_cache.azure_openai_llm_options: + kwargs.update(config_cache.azure_openai_llm_options) return await azure_openai_complete_if_cache( args.llm_model, @@ -370,12 +379,15 @@ def create_app(args): base_url=args.llm_binding_host, api_key=os.getenv("AZURE_OPENAI_API_KEY", args.llm_binding_api_key), api_version=os.getenv("AZURE_OPENAI_API_VERSION", "2024-08-01-preview"), + token_tracker=getattr(app.state, "token_tracker", None) + if enable_token_tracking and hasattr(app.state, "token_tracker") + else None, **kwargs, ) return optimized_azure_openai_model_complete - def create_llm_model_func(binding: str): + def create_llm_model_func(binding: str, enable_token_tracking=False): """ Create LLM model function based on binding type. Uses optimized functions for OpenAI bindings and lazy import for others. 
@@ -384,21 +396,42 @@ def create_app(args): if binding == "lollms": from lightrag.llm.lollms import lollms_model_complete - return lollms_model_complete + async def lollms_model_complete_with_tracker(*args, **kwargs): + # Add token tracker if enabled + if enable_token_tracking and hasattr(app.state, "token_tracker"): + kwargs["token_tracker"] = app.state.token_tracker + return await lollms_model_complete(*args, **kwargs) + + return lollms_model_complete_with_tracker elif binding == "ollama": from lightrag.llm.ollama import ollama_model_complete - return ollama_model_complete + async def ollama_model_complete_with_tracker(*args, **kwargs): + # Add token tracker if enabled + if enable_token_tracking and hasattr(app.state, "token_tracker"): + kwargs["token_tracker"] = app.state.token_tracker + return await ollama_model_complete(*args, **kwargs) + + return ollama_model_complete_with_tracker elif binding == "aws_bedrock": - return bedrock_model_complete # Already defined locally + + async def bedrock_model_complete_with_tracker(*args, **kwargs): + # Add token tracker if enabled + if enable_token_tracking and hasattr(app.state, "token_tracker"): + kwargs["token_tracker"] = app.state.token_tracker + return await bedrock_model_complete(*args, **kwargs) + + return bedrock_model_complete_with_tracker elif binding == "azure_openai": # Use optimized function with pre-processed configuration return create_optimized_azure_openai_llm_func( - config_cache, args, llm_timeout + config_cache, args, llm_timeout, enable_token_tracking ) else: # openai and compatible # Use optimized function with pre-processed configuration - return create_optimized_openai_llm_func(config_cache, args, llm_timeout) + return create_optimized_openai_llm_func( + config_cache, args, llm_timeout, enable_token_tracking + ) except ImportError as e: raise Exception(f"Failed to import {binding} LLM binding: {e}") @@ -422,7 +455,14 @@ def create_app(args): return {} def create_optimized_embedding_function( - 
config_cache: LLMConfigCache, binding, model, host, api_key, dimensions, args + config_cache: LLMConfigCache, + binding, + model, + host, + api_key, + dimensions, + args, + enable_token_tracking=False, ): """ Create optimized embedding function with pre-processed configuration for applicable bindings. @@ -435,7 +475,13 @@ def create_app(args): from lightrag.llm.lollms import lollms_embed return await lollms_embed( - texts, embed_model=model, host=host, api_key=api_key + texts, + embed_model=model, + host=host, + api_key=api_key, + token_tracker=getattr(app.state, "token_tracker", None) + if enable_token_tracking and hasattr(app.state, "token_tracker") + else None, ) elif binding == "ollama": from lightrag.llm.ollama import ollama_embed @@ -455,26 +501,54 @@ def create_app(args): host=host, api_key=api_key, options=ollama_options, + token_tracker=getattr(app.state, "token_tracker", None) + if enable_token_tracking and hasattr(app.state, "token_tracker") + else None, ) elif binding == "azure_openai": from lightrag.llm.azure_openai import azure_openai_embed - return await azure_openai_embed(texts, model=model, api_key=api_key) + return await azure_openai_embed( + texts, + model=model, + api_key=api_key, + token_tracker=getattr(app.state, "token_tracker", None) + if enable_token_tracking and hasattr(app.state, "token_tracker") + else None, + ) elif binding == "aws_bedrock": from lightrag.llm.bedrock import bedrock_embed - return await bedrock_embed(texts, model=model) + return await bedrock_embed( + texts, + model=model, + token_tracker=getattr(app.state, "token_tracker", None) + if enable_token_tracking and hasattr(app.state, "token_tracker") + else None, + ) elif binding == "jina": from lightrag.llm.jina import jina_embed return await jina_embed( - texts, dimensions=dimensions, base_url=host, api_key=api_key + texts, + dimensions=dimensions, + base_url=host, + api_key=api_key, + token_tracker=getattr(app.state, "token_tracker", None) + if enable_token_tracking and 
hasattr(app.state, "token_tracker") + else None, ) else: # openai and compatible from lightrag.llm.openai import openai_embed return await openai_embed( - texts, model=model, base_url=host, api_key=api_key + texts, + model=model, + base_url=host, + api_key=api_key, + token_tracker=getattr(app.state, "token_tracker", None) + if enable_token_tracking and hasattr(app.state, "token_tracker") + else None, ) except ImportError as e: raise Exception(f"Failed to import {binding} embedding: {e}") @@ -594,7 +668,9 @@ def create_app(args): rag = LightRAG( working_dir=args.working_dir, workspace=args.workspace, - llm_model_func=create_llm_model_func(args.llm_binding), + llm_model_func=create_llm_model_func( + args.llm_binding, args.enable_token_tracking + ), llm_model_name=args.llm_model, llm_model_max_async=args.max_async, summary_max_tokens=args.summary_max_tokens, @@ -604,7 +680,16 @@ def create_app(args): llm_model_kwargs=create_llm_model_kwargs( args.llm_binding, args, llm_timeout ), - embedding_func=embedding_func, + embedding_func=create_optimized_embedding_function( + config_cache, + args.embedding_binding, + args.embedding_model, + args.embedding_binding_host, + args.embedding_binding_api_key, + args.embedding_dim, + args, + args.enable_token_tracking, + ), default_llm_timeout=llm_timeout, default_embedding_timeout=embedding_timeout, kv_storage=args.kv_storage, @@ -629,6 +714,17 @@ def create_app(args): logger.error(f"Failed to initialize LightRAG: {e}") raise + # Initialize token tracking if enabled + token_tracker = None + if args.enable_token_tracking: + from lightrag.utils import TokenTracker + + token_tracker = TokenTracker() + logger.info("Token tracking enabled") + + # Add token tracker to the app state for use in endpoints + app.state.token_tracker = token_tracker + # Add routes app.include_router( create_document_routes( @@ -637,7 +733,7 @@ def create_app(args): api_key, ) ) - app.include_router(create_query_routes(rag, api_key, args.top_k)) + 
app.include_router(create_query_routes(rag, api_key, args.top_k, token_tracker)) app.include_router(create_graph_routes(rag, api_key)) # Add Ollama API routes diff --git a/lightrag/api/routers/document_routes.py b/lightrag/api/routers/document_routes.py index 6ca873e8..f8bbb9ab 100644 --- a/lightrag/api/routers/document_routes.py +++ b/lightrag/api/routers/document_routes.py @@ -243,6 +243,7 @@ class InsertResponse(BaseModel): status: Status of the operation (success, duplicated, partial_success, failure) message: Detailed message describing the operation result track_id: Tracking ID for monitoring processing status + token_usage: Token usage statistics for the insertion (prompt_tokens, completion_tokens, total_tokens, call_count) """ status: Literal["success", "duplicated", "partial_success", "failure"] = Field( @@ -250,6 +251,10 @@ class InsertResponse(BaseModel): ) message: str = Field(description="Message describing the operation result") track_id: str = Field(description="Tracking ID for monitoring processing status") + token_usage: Optional[Dict[str, int]] = Field( + default=None, + description="Token usage statistics for the insertion (prompt_tokens, completion_tokens, total_tokens, call_count)", + ) class Config: json_schema_extra = { @@ -257,10 +262,17 @@ class InsertResponse(BaseModel): "status": "success", "message": "File 'document.pdf' uploaded successfully. 
Processing will continue in background.", "track_id": "upload_20250729_170612_abc123", + "token_usage": { + "prompt_tokens": 1250, + "completion_tokens": 450, + "total_tokens": 1700, + "call_count": 3, + }, } } + class ClearDocumentsResponse(BaseModel): """Response model for document clearing operation diff --git a/lightrag/api/routers/query_routes.py b/lightrag/api/routers/query_routes.py index 1ebe7cbe..046a1ea4 100644 --- a/lightrag/api/routers/query_routes.py +++ b/lightrag/api/routers/query_routes.py @@ -6,7 +6,7 @@ import json import logging from typing import Any, Dict, List, Literal, Optional -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, Depends, HTTPException, Request, Response from lightrag.base import QueryParam from lightrag.types import MetadataFilter from lightrag.api.utils_api import get_combined_auth_dependency @@ -157,6 +157,10 @@ class QueryResponse(BaseModel): default=None, description="Reference list (Disabled when include_references=False, /query/data always includes references.)", ) + token_usage: Optional[Dict[str, int]] = Field( + default=None, + description="Token usage statistics for the query (prompt_tokens, completion_tokens, total_tokens)", + ) class QueryDataResponse(BaseModel): @@ -168,6 +172,10 @@ class QueryDataResponse(BaseModel): metadata: Dict[str, Any] = Field( description="Query metadata including mode, keywords, and processing information" ) + token_usage: Optional[Dict[str, int]] = Field( + default=None, + description="Token usage statistics for the query (prompt_tokens, completion_tokens, total_tokens)", + ) class StreamChunkResponse(BaseModel): @@ -183,9 +191,18 @@ class StreamChunkResponse(BaseModel): error: Optional[str] = Field( default=None, description="Error message if processing fails" ) + token_usage: Optional[Dict[str, int]] = Field( + default=None, + description="Token usage statistics for the entire query (only in final chunk)", + ) -def create_query_routes(rag, api_key: 
Optional[str] = None, top_k: int = 60): +def create_query_routes( + rag, + api_key: Optional[str] = None, + top_k: int = 60, + token_tracker: Optional[Any] = None, +): combined_auth = get_combined_auth_dependency(api_key) @router.post( @@ -220,7 +237,7 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60): }, "examples": { "with_references": { - "summary": "Response with references", + "summary": "Response with references and token usage", "description": "Example response when include_references=True", "value": { "response": "Artificial Intelligence (AI) is a branch of computer science that aims to create intelligent machines capable of performing tasks that typically require human intelligence, such as learning, reasoning, and problem-solving.", @@ -234,13 +251,25 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60): "file_path": "/documents/machine_learning.txt", }, ], + "token_usage": { + "prompt_tokens": 245, + "completion_tokens": 87, + "total_tokens": 332, + "call_count": 1, + }, }, }, "without_references": { - "summary": "Response without references", + "summary": "Response without references but with token usage", "description": "Example response when include_references=False", "value": { - "response": "Artificial Intelligence (AI) is a branch of computer science that aims to create intelligent machines capable of performing tasks that typically require human intelligence, such as learning, reasoning, and problem-solving." 
+ "response": "Artificial Intelligence (AI) is a branch of computer science that aims to create intelligent machines capable of performing tasks that typically require human intelligence, such as learning, reasoning, and problem-solving.", + "token_usage": { + "prompt_tokens": 245, + "completion_tokens": 87, + "total_tokens": 332, + "call_count": 1, + }, }, }, "different_modes": { @@ -358,6 +387,10 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60): - 500: Internal processing error (e.g., LLM service unavailable) """ try: + # Reset token tracker at start of query if available + if token_tracker: + token_tracker.reset() + param = request.to_query_params( False ) # Ensure stream=False for non-streaming endpoint @@ -376,11 +409,22 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60): if not response_content: response_content = "No relevant context found for the query." + # Get token usage if available + token_usage = None + if token_tracker: + token_usage = token_tracker.get_usage() + # Return response with or without references based on request if request.include_references: - return QueryResponse(response=response_content, references=references) + return QueryResponse( + response=response_content, + references=references, + token_usage=token_usage, + ) else: - return QueryResponse(response=response_content, references=None) + return QueryResponse( + response=response_content, references=None, token_usage=token_usage + ) except Exception as e: trace_exception(e) raise HTTPException(status_code=500, detail=str(e)) @@ -577,6 +621,10 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60): Use streaming mode for real-time interfaces and non-streaming for batch processing. 
""" try: + # Reset token tracker at start of query if available + if token_tracker: + token_tracker.reset() + # Use the stream parameter from the request, defaulting to True if not specified stream_mode = request.stream if request.stream is not None else True param = request.to_query_params(stream_mode) @@ -605,6 +653,10 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60): except Exception as e: logging.error(f"Streaming error: {str(e)}") yield f"{json.dumps({'error': str(e)})}\n" + + # Add final token usage chunk if streaming and token tracker is available + if token_tracker and llm_response.get("is_streaming"): + yield f"{json.dumps({'token_usage': token_tracker.get_usage()})}\n" else: # Non-streaming mode: send complete response in one message response_content = llm_response.get("content", "") @@ -616,6 +668,10 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60): if request.include_references: complete_response["references"] = references + # Add token usage if available + if token_tracker: + complete_response["token_usage"] = token_tracker.get_usage() + yield f"{json.dumps(complete_response)}\n" return StreamingResponse( @@ -1022,18 +1078,31 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60): as structured data analysis typically requires source attribution. 
""" try: + # Reset token tracker at start of query if available + if token_tracker: + token_tracker.reset() + param = request.to_query_params(False) # No streaming for data endpoint response = await rag.aquery_data(request.query, param=param) + # Get token usage if available + token_usage = None + if token_tracker: + token_usage = token_tracker.get_usage() + # aquery_data returns the new format with status, message, data, and metadata if isinstance(response, dict): - return QueryDataResponse(**response) + response_dict = dict(response) + response_dict["token_usage"] = token_usage + return QueryDataResponse(**response_dict) else: # Handle unexpected response format return QueryDataResponse( status="failure", - message="Invalid response type", + message="Unexpected response format", data={}, + metadata={}, + token_usage=None, ) except Exception as e: trace_exception(e) diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 851a69de..05827dc9 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -870,7 +870,8 @@ class LightRAG: ids: str | list[str] | None = None, file_paths: str | list[str] | None = None, track_id: str | None = None, - ) -> str: + token_tracker: TokenTracker | None = None, + ) -> tuple[str, dict]: """Sync Insert documents with checkpoint support Args: @@ -884,10 +885,10 @@ class LightRAG: track_id: tracking ID for monitoring processing status, if not provided, will be generated Returns: - str: tracking ID for monitoring processing status + tuple[str, dict]: (tracking ID for monitoring processing status, token usage statistics) """ loop = always_get_an_event_loop() - return loop.run_until_complete( + result = loop.run_until_complete( self.ainsert( input, split_by_character, @@ -895,8 +896,10 @@ class LightRAG: ids, file_paths, track_id, + token_tracker, ) ) + return result async def ainsert( self, @@ -906,7 +909,8 @@ class LightRAG: ids: str | list[str] | None = None, file_paths: str | list[str] | None = None, track_id: str | None = 
None, - ) -> str: + token_tracker: TokenTracker | None = None, + ) -> tuple[str, dict]: """Async Insert documents with checkpoint support Args: @@ -930,9 +934,15 @@ class LightRAG: await self.apipeline_process_enqueue_documents( split_by_character, split_by_character_only, + token_tracker, ) - return track_id + return track_id, token_tracker.get_usage() if token_tracker else { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0, + "call_count": 0, + } # TODO: deprecated, use insert instead def insert_custom_chunks( @@ -1107,7 +1117,9 @@ class LightRAG: "file_path" ], # Store file path in document status "track_id": track_id, # Store track_id in document status - "metadata": metadata[i] if isinstance(metadata, list) and i < len(metadata) else metadata, # added provided custom metadata per document + "metadata": metadata[i] + if isinstance(metadata, list) and i < len(metadata) + else metadata, # added provided custom metadata per document } for i, (id_, content_data) in enumerate(contents.items()) } @@ -1363,6 +1375,7 @@ class LightRAG: self, split_by_character: str | None = None, split_by_character_only: bool = False, + token_tracker: TokenTracker | None = None, ) -> None: """ Process pending documents by splitting them into chunks, processing @@ -1484,9 +1497,12 @@ class LightRAG: pipeline_status: dict, pipeline_status_lock: asyncio.Lock, semaphore: asyncio.Semaphore, + token_tracker: TokenTracker | None = None, ) -> None: """Process single document""" - doc_metadata = getattr(status_doc, "metadata", None) + doc_metadata = getattr(status_doc, "metadata", None) + if doc_metadata is None: + doc_metadata = {} file_extraction_stage_ok = False async with semaphore: nonlocal processed_count @@ -1498,7 +1514,6 @@ class LightRAG: # Get file path from status document file_path = getattr( status_doc, "file_path", "unknown_source" - ) async with pipeline_status_lock: @@ -1561,7 +1576,9 @@ class LightRAG: # Process document in two stages # Stage 1: Process text 
chunks and docs (parallel execution) - doc_metadata["processing_start_time"] = processing_start_time + doc_metadata["processing_start_time"] = ( + processing_start_time + ) doc_status_task = asyncio.create_task( self.doc_status.upsert( @@ -1610,6 +1627,7 @@ class LightRAG: doc_metadata, pipeline_status, pipeline_status_lock, + token_tracker, ) ) await entity_relation_task @@ -1643,7 +1661,9 @@ class LightRAG: processing_end_time = int(time.time()) # Update document status to failed - doc_metadata["processing_start_time"] = processing_start_time + doc_metadata["processing_start_time"] = ( + processing_start_time + ) doc_metadata["processing_end_time"] = processing_end_time await self.doc_status.upsert( { @@ -1691,7 +1711,9 @@ class LightRAG: doc_metadata["processing_start_time"] = ( processing_start_time ) - doc_metadata["processing_end_time"] = processing_end_time + doc_metadata["processing_end_time"] = ( + processing_end_time + ) await self.doc_status.upsert( { @@ -1748,7 +1770,9 @@ class LightRAG: doc_metadata["processing_start_time"] = ( processing_start_time ) - doc_metadata["processing_end_time"] = processing_end_time + doc_metadata["processing_end_time"] = ( + processing_end_time + ) await self.doc_status.upsert( { @@ -1778,6 +1802,7 @@ class LightRAG: pipeline_status, pipeline_status_lock, semaphore, + token_tracker, ) ) @@ -1827,6 +1852,7 @@ class LightRAG: metadata: dict | None, pipeline_status=None, pipeline_status_lock=None, + token_tracker: TokenTracker | None = None, ) -> list: try: chunk_results = await extract_entities( @@ -1837,6 +1863,7 @@ class LightRAG: pipeline_status_lock=pipeline_status_lock, llm_response_cache=self.llm_response_cache, text_chunks_storage=self.text_chunks, + token_tracker=token_tracker, ) return chunk_results except Exception as e: @@ -2061,14 +2088,18 @@ class LightRAG: self, query: str, param: QueryParam = QueryParam(), + token_tracker: TokenTracker | None = None, system_prompt: str | None = None, - ) -> str | Iterator[str]: 
+ ) -> Any: """ - Perform a sync query. + User query interface (backward compatibility wrapper). + + Delegates to aquery() for asynchronous execution and returns the result. Args: query (str): The query to be executed. param (QueryParam): Configuration parameters for query execution. + token_tracker (TokenTracker | None): Optional token tracker for monitoring usage. prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"]. Returns: @@ -2076,14 +2107,17 @@ class LightRAG: """ loop = always_get_an_event_loop() - return loop.run_until_complete(self.aquery(query, param, system_prompt)) # type: ignore + return loop.run_until_complete( + self.aquery(query, param, token_tracker, system_prompt) + ) # type: ignore async def aquery( self, query: str, param: QueryParam = QueryParam(), + token_tracker: TokenTracker | None = None, system_prompt: str | None = None, - ) -> str | AsyncIterator[str]: + ) -> Any: """ Perform a async query (backward compatibility wrapper). @@ -2102,7 +2136,7 @@ class LightRAG: - Streaming: Returns AsyncIterator[str] """ # Call the new aquery_llm function to get complete results - result = await self.aquery_llm(query, param, system_prompt) + result = await self.aquery_llm(query, param, system_prompt, token_tracker) # Extract and return only the LLM response for backward compatibility llm_response = result.get("llm_response", {}) @@ -2137,6 +2171,7 @@ class LightRAG: self, query: str, param: QueryParam = QueryParam(), + token_tracker: TokenTracker | None = None, ) -> dict[str, Any]: """ Asynchronous data retrieval API: returns structured retrieval results without LLM generation. @@ -2330,6 +2365,7 @@ class LightRAG: query: str, param: QueryParam = QueryParam(), system_prompt: str | None = None, + token_tracker: TokenTracker | None = None, ) -> dict[str, Any]: """ Asynchronous complete query API: returns structured retrieval results with LLM generation. 
@@ -2364,6 +2400,7 @@ class LightRAG: hashing_kv=self.llm_response_cache, system_prompt=system_prompt, chunks_vdb=self.chunks_vdb, + token_tracker=token_tracker, ) elif param.mode == "naive": query_result = await naive_query( @@ -2373,6 +2410,7 @@ class LightRAG: global_config, hashing_kv=self.llm_response_cache, system_prompt=system_prompt, + token_tracker=token_tracker, ) elif param.mode == "bypass": # Bypass mode: directly use LLM without knowledge retrieval diff --git a/lightrag/llm/azure_openai.py b/lightrag/llm/azure_openai.py index 824ff088..e2793238 100644 --- a/lightrag/llm/azure_openai.py +++ b/lightrag/llm/azure_openai.py @@ -46,6 +46,7 @@ async def azure_openai_complete_if_cache( base_url: str | None = None, api_key: str | None = None, api_version: str | None = None, + token_tracker: Any | None = None, **kwargs, ): if enable_cot: @@ -94,28 +95,73 @@ async def azure_openai_complete_if_cache( ) if hasattr(response, "__aiter__"): + final_chunk_usage = None + accumulated_response = "" async def inner(): - async for chunk in response: - if len(chunk.choices) == 0: - continue - content = chunk.choices[0].delta.content - if content is None: - continue - if r"\u" in content: - content = safe_unicode_decode(content.encode("utf-8")) - yield content + nonlocal final_chunk_usage, accumulated_response + try: + async for chunk in response: + if len(chunk.choices) == 0: + continue + content = chunk.choices[0].delta.content + if content is None: + continue + accumulated_response += content + if r"\u" in content: + content = safe_unicode_decode(content.encode("utf-8")) + yield content + + # Check for usage in the last chunk + if hasattr(chunk, "usage") and chunk.usage is not None: + final_chunk_usage = chunk.usage + except Exception as e: + logger.error(f"Error in Azure OpenAI stream response: {str(e)}") + raise + finally: + # After streaming is complete, track token usage + if token_tracker and final_chunk_usage: + # Use actual usage from the API + token_counts = { + 
"prompt_tokens": getattr(final_chunk_usage, "prompt_tokens", 0), + "completion_tokens": getattr( + final_chunk_usage, "completion_tokens", 0 + ), + "total_tokens": getattr(final_chunk_usage, "total_tokens", 0), + } + token_tracker.add_usage(token_counts) + logger.debug(f"Azure OpenAI streaming token usage: {token_counts}") + elif token_tracker: + logger.debug( + "No usage information available in Azure OpenAI streaming response" + ) return inner() else: content = response.choices[0].message.content if r"\u" in content: content = safe_unicode_decode(content.encode("utf-8")) + + # Track token usage for non-streaming response + if token_tracker and hasattr(response, "usage"): + token_counts = { + "prompt_tokens": getattr(response.usage, "prompt_tokens", 0), + "completion_tokens": getattr(response.usage, "completion_tokens", 0), + "total_tokens": getattr(response.usage, "total_tokens", 0), + } + token_tracker.add_usage(token_counts) + logger.debug(f"Azure OpenAI non-streaming token usage: {token_counts}") + return content async def azure_openai_complete( - prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs + prompt, + system_prompt=None, + history_messages=[], + keyword_extraction=False, + token_tracker=None, + **kwargs, ) -> str: kwargs.pop("keyword_extraction", None) result = await azure_openai_complete_if_cache( @@ -123,6 +169,7 @@ async def azure_openai_complete( prompt, system_prompt=system_prompt, history_messages=history_messages, + token_tracker=token_tracker, **kwargs, ) return result @@ -142,6 +189,7 @@ async def azure_openai_embed( base_url: str | None = None, api_key: str | None = None, api_version: str | None = None, + token_tracker: Any | None = None, ) -> np.ndarray: deployment = ( os.getenv("AZURE_EMBEDDING_DEPLOYMENT") @@ -174,4 +222,14 @@ async def azure_openai_embed( response = await openai_async_client.embeddings.create( model=model, input=texts, encoding_format="float" ) + + # Track token usage for embeddings if 
token tracker is provided + if token_tracker and hasattr(response, "usage"): + token_counts = { + "prompt_tokens": getattr(response.usage, "prompt_tokens", 0), + "total_tokens": getattr(response.usage, "total_tokens", 0), + } + token_tracker.add_usage(token_counts) + logger.debug(f"Azure OpenAI embedding token usage: {token_counts}") + return np.array([dp.embedding for dp in response.data]) diff --git a/lightrag/llm/bedrock.py b/lightrag/llm/bedrock.py index 16737341..f3f04cfd 100644 --- a/lightrag/llm/bedrock.py +++ b/lightrag/llm/bedrock.py @@ -48,6 +48,7 @@ async def bedrock_complete_if_cache( aws_access_key_id=None, aws_secret_access_key=None, aws_session_token=None, + token_tracker=None, **kwargs, ) -> Union[str, AsyncIterator[str]]: if enable_cot: @@ -155,6 +156,18 @@ async def bedrock_complete_if_cache( yield text # Handle other event types that might indicate stream end elif "messageStop" in event: + # Track token usage for streaming if token tracker is provided + if token_tracker and "usage" in event: + usage = event["usage"] + token_counts = { + "prompt_tokens": usage.get("inputTokens", 0), + "completion_tokens": usage.get("outputTokens", 0), + "total_tokens": usage.get("totalTokens", 0), + } + token_tracker.add_usage(token_counts) + logging.debug( + f"Bedrock streaming token usage: {token_counts}" + ) break except Exception as e: @@ -228,6 +241,17 @@ async def bedrock_complete_if_cache( if not content or content.strip() == "": raise BedrockError("Received empty content from Bedrock API") + # Track token usage for non-streaming if token tracker is provided + if token_tracker and "usage" in response: + usage = response["usage"] + token_counts = { + "prompt_tokens": usage.get("inputTokens", 0), + "completion_tokens": usage.get("outputTokens", 0), + "total_tokens": usage.get("totalTokens", 0), + } + token_tracker.add_usage(token_counts) + logging.debug(f"Bedrock non-streaming token usage: {token_counts}") + return content except Exception as e: @@ -239,7 
+263,12 @@ async def bedrock_complete_if_cache( # Generic Bedrock completion function async def bedrock_complete( - prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs + prompt, + system_prompt=None, + history_messages=[], + keyword_extraction=False, + token_tracker=None, + **kwargs, ) -> Union[str, AsyncIterator[str]]: kwargs.pop("keyword_extraction", None) model_name = kwargs["hashing_kv"].global_config["llm_model_name"] @@ -248,6 +277,7 @@ async def bedrock_complete( prompt, system_prompt=system_prompt, history_messages=history_messages, + token_tracker=token_tracker, **kwargs, ) return result @@ -265,6 +295,7 @@ async def bedrock_embed( aws_access_key_id=None, aws_secret_access_key=None, aws_session_token=None, + token_tracker=None, ) -> np.ndarray: # Respect existing env; only set if a non-empty value is available access_key = os.environ.get("AWS_ACCESS_KEY_ID") or aws_access_key_id diff --git a/lightrag/llm/lollms.py b/lightrag/llm/lollms.py index 9274dbfc..2fe2561f 100644 --- a/lightrag/llm/lollms.py +++ b/lightrag/llm/lollms.py @@ -108,6 +108,7 @@ async def lollms_model_complete( history_messages=[], enable_cot: bool = False, keyword_extraction=False, + token_tracker=None, **kwargs, ) -> Union[str, AsyncIterator[str]]: """Complete function for lollms model generation.""" @@ -135,7 +136,11 @@ async def lollms_model_complete( async def lollms_embed( - texts: List[str], embed_model=None, base_url="http://localhost:9600", **kwargs + texts: List[str], + embed_model=None, + base_url="http://localhost:9600", + token_tracker=None, + **kwargs, ) -> np.ndarray: """ Generate embeddings for a list of texts using lollms server. 
@@ -144,6 +149,7 @@ async def lollms_embed( texts: List of strings to embed embed_model: Model name (not used directly as lollms uses configured vectorizer) base_url: URL of the lollms server + token_tracker: Optional token usage tracker for monitoring API usage **kwargs: Additional arguments passed to the request Returns: diff --git a/lightrag/llm/ollama.py b/lightrag/llm/ollama.py index b013496e..01f5e06c 100644 --- a/lightrag/llm/ollama.py +++ b/lightrag/llm/ollama.py @@ -39,6 +39,7 @@ async def _ollama_model_if_cache( system_prompt=None, history_messages=[], enable_cot: bool = False, + token_tracker=None, **kwargs, ) -> Union[str, AsyncIterator[str]]: if enable_cot: @@ -74,13 +75,47 @@ async def _ollama_model_if_cache( """cannot cache stream response and process reasoning""" async def inner(): + accumulated_response = "" try: async for chunk in response: - yield chunk["message"]["content"] + chunk_content = chunk["message"]["content"] + accumulated_response += chunk_content + yield chunk_content except Exception as e: logger.error(f"Error in stream response: {str(e)}") raise finally: + # Track token usage for streaming if token tracker is provided + if token_tracker: + # Estimate prompt tokens: roughly 4 characters per token for English text + prompt_text = "" + if system_prompt: + prompt_text += system_prompt + " " + prompt_text += ( + " ".join( + [msg.get("content", "") for msg in history_messages] + ) + + " " + ) + prompt_text += prompt + prompt_tokens = len(prompt_text) // 4 + ( + 1 if len(prompt_text) % 4 else 0 + ) + + # Estimate completion tokens from accumulated response + completion_tokens = len(accumulated_response) // 4 + ( + 1 if len(accumulated_response) % 4 else 0 + ) + total_tokens = prompt_tokens + completion_tokens + + token_tracker.add_usage( + { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": total_tokens, + } + ) + try: await ollama_client._client.aclose() logger.debug("Successfully closed Ollama 
client for streaming") @@ -91,6 +126,35 @@ async def _ollama_model_if_cache( else: model_response = response["message"]["content"] + # Track token usage if token tracker is provided + # Note: Ollama doesn't provide token usage in chat responses, so we estimate + if token_tracker: + # Estimate prompt tokens: roughly 4 characters per token for English text + prompt_text = "" + if system_prompt: + prompt_text += system_prompt + " " + prompt_text += ( + " ".join([msg.get("content", "") for msg in history_messages]) + " " + ) + prompt_text += prompt + prompt_tokens = len(prompt_text) // 4 + ( + 1 if len(prompt_text) % 4 else 0 + ) + + # Estimate completion tokens from response + completion_tokens = len(model_response) // 4 + ( + 1 if len(model_response) % 4 else 0 + ) + total_tokens = prompt_tokens + completion_tokens + + token_tracker.add_usage( + { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": total_tokens, + } + ) + """ If the model also wraps its thoughts in a specific tag, this information is not needed for the final @@ -126,6 +190,7 @@ async def ollama_model_complete( history_messages=[], enable_cot: bool = False, keyword_extraction=False, + token_tracker=None, **kwargs, ) -> Union[str, AsyncIterator[str]]: keyword_extraction = kwargs.pop("keyword_extraction", None) @@ -138,11 +203,14 @@ async def ollama_model_complete( system_prompt=system_prompt, history_messages=history_messages, enable_cot=enable_cot, + token_tracker=token_tracker, **kwargs, ) -async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray: +async def ollama_embed( + texts: list[str], embed_model, token_tracker=None, **kwargs +) -> np.ndarray: api_key = kwargs.pop("api_key", None) headers = { "Content-Type": "application/json", @@ -160,6 +228,21 @@ async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray: data = await ollama_client.embed( model=embed_model, input=texts, options=options ) + + # Track token usage 
if token tracker is provided + # Note: Ollama doesn't provide token usage in embedding responses, so we estimate + if token_tracker: + # Estimate tokens: roughly 4 characters per token for English text + total_chars = sum(len(text) for text in texts) + estimated_tokens = total_chars // 4 + (1 if total_chars % 4 else 0) + token_tracker.add_usage( + { + "prompt_tokens": estimated_tokens, + "completion_tokens": 0, + "total_tokens": estimated_tokens, + } + ) + return np.array(data["embeddings"]) except Exception as e: logger.error(f"Error in ollama_embed: {str(e)}") diff --git a/lightrag/operate.py b/lightrag/operate.py index 8ba92181..13262d82 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -12,6 +12,7 @@ from .utils import ( logger, compute_mdhash_id, Tokenizer, + TokenTracker, is_float_regex, sanitize_and_normalize_extracted_text, pack_user_ass_to_openai_messages, @@ -126,6 +127,7 @@ async def _handle_entity_relation_summary( seperator: str, global_config: dict, llm_response_cache: BaseKVStorage | None = None, + token_tracker: TokenTracker | None = None, ) -> tuple[str, bool]: """Handle entity relation description summary using map-reduce approach. @@ -188,6 +190,7 @@ async def _handle_entity_relation_summary( current_list, global_config, llm_response_cache, + token_tracker, ) return final_summary, True # LLM was used for final summarization @@ -243,6 +246,7 @@ async def _handle_entity_relation_summary( chunk, global_config, llm_response_cache, + token_tracker, ) new_summaries.append(summary) llm_was_used = True # Mark that LLM was used in reduce phase @@ -257,6 +261,7 @@ async def _summarize_descriptions( description_list: list[str], global_config: dict, llm_response_cache: BaseKVStorage | None = None, + token_tracker: TokenTracker | None = None, ) -> str: """Helper function to summarize a list of descriptions using LLM. 
@@ -312,9 +317,10 @@ async def _summarize_descriptions( # Use LLM function with cache (higher priority for summary generation) summary, _ = await use_llm_func_with_cache( use_prompt, - use_llm_func, - llm_response_cache=llm_response_cache, + use_llm_func=use_llm_func, + hashing_kv=llm_response_cache, cache_type="summary", + token_tracker=token_tracker, ) return summary @@ -405,7 +411,7 @@ async def _handle_single_relationship_extraction( ): # treat "relationship" and "relation" interchangeable if len(record_attributes) > 1 and "relation" in record_attributes[0]: logger.warning( - f"{chunk_key}: LLM output format error; found {len(record_attributes)}/5 fields on REALTION `{record_attributes[1]}`~`{record_attributes[2] if len(record_attributes) >2 else 'N/A'}`" + f"{chunk_key}: LLM output format error; found {len(record_attributes)}/5 fields on RELATION `{record_attributes[1]}`~`{record_attributes[2] if len(record_attributes) > 2 else 'N/A'}`" ) logger.debug(record_attributes) return None @@ -463,7 +469,6 @@ async def _handle_single_relationship_extraction( file_path=file_path, timestamp=timestamp, metadata=metadata, - ) except ValueError as e: @@ -2037,6 +2042,7 @@ async def extract_entities( pipeline_status_lock=None, llm_response_cache: BaseKVStorage | None = None, text_chunks_storage: BaseKVStorage | None = None, + token_tracker: TokenTracker | None = None, ) -> list: use_llm_func: callable = global_config["llm_model_func"] entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"] @@ -2150,12 +2156,13 @@ async def extract_entities( final_result, timestamp = await use_llm_func_with_cache( entity_extraction_user_prompt, - use_llm_func, + use_llm_func=use_llm_func, system_prompt=entity_extraction_system_prompt, - llm_response_cache=llm_response_cache, + hashing_kv=llm_response_cache, cache_type="extract", chunk_id=chunk_key, cache_keys_collector=cache_keys_collector, + token_tracker=token_tracker, ) history = pack_user_ass_to_openai_messages( @@ 
-2177,16 +2184,16 @@ async def extract_entities( if entity_extract_max_gleaning > 0: glean_result, timestamp = await use_llm_func_with_cache( entity_continue_extraction_user_prompt, - use_llm_func, + use_llm_func=use_llm_func, system_prompt=entity_extraction_system_prompt, - llm_response_cache=llm_response_cache, + hashing_kv=llm_response_cache, history_messages=history, cache_type="extract", chunk_id=chunk_key, cache_keys_collector=cache_keys_collector, + token_tracker=token_tracker, ) - # Process gleaning result separately with file path and metadata glean_nodes, glean_edges = await _process_extraction_result( glean_result, @@ -2300,7 +2307,7 @@ async def extract_entities( await asyncio.wait(pending) # Add progress prefix to the exception message - progress_prefix = f"C[{processed_chunks+1}/{total_chunks}]" + progress_prefix = f"C[{processed_chunks + 1}/{total_chunks}]" # Re-raise the original exception with a prefix prefixed_exception = create_prefixed_exception(first_exception, progress_prefix) @@ -2324,6 +2331,7 @@ async def kg_query( system_prompt: str | None = None, chunks_vdb: BaseVectorStorage = None, return_raw_data: Literal[True] = False, + token_tracker: TokenTracker | None = None, ) -> dict[str, Any]: ... @@ -2341,6 +2349,7 @@ async def kg_query( chunks_vdb: BaseVectorStorage = None, metadata_filters: list | None = None, return_raw_data: Literal[False] = False, + token_tracker: TokenTracker | None = None, ) -> str | AsyncIterator[str]: ... @@ -2355,6 +2364,7 @@ async def kg_query( hashing_kv: BaseKVStorage | None = None, system_prompt: str | None = None, chunks_vdb: BaseVectorStorage = None, + token_tracker: TokenTracker | None = None, ) -> QueryResult: """ Execute knowledge graph query and return unified QueryResult object. 
@@ -2422,7 +2432,7 @@ async def kg_query( return QueryResult(content=cached_response) hl_keywords, ll_keywords = await get_keywords_from_query( - query, query_param, global_config, hashing_kv + query, query_param, global_config, hashing_kv, token_tracker ) logger.debug(f"High-level keywords: {hl_keywords}") @@ -2526,6 +2536,7 @@ async def kg_query( history_messages=query_param.conversation_history, enable_cot=True, stream=query_param.stream, + token_tracker=token_tracker, ) if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"): @@ -2583,6 +2594,7 @@ async def get_keywords_from_query( query_param: QueryParam, global_config: dict[str, str], hashing_kv: BaseKVStorage | None = None, + token_tracker: TokenTracker | None = None, ) -> tuple[list[str], list[str]]: """ Retrieves high-level and low-level keywords for RAG operations. @@ -2605,7 +2617,7 @@ async def get_keywords_from_query( # Extract keywords using extract_keywords_only function which already supports conversation history hl_keywords, ll_keywords = await extract_keywords_only( - query, query_param, global_config, hashing_kv + query, query_param, global_config, hashing_kv, token_tracker ) return hl_keywords, ll_keywords @@ -2615,6 +2627,7 @@ async def extract_keywords_only( param: QueryParam, global_config: dict[str, str], hashing_kv: BaseKVStorage | None = None, + token_tracker: TokenTracker | None = None, ) -> tuple[list[str], list[str]]: """ Extract high-level and low-level keywords from the given 'text' using the LLM. @@ -2668,7 +2681,9 @@ async def extract_keywords_only( # Apply higher priority (5) to query relation LLM function use_model_func = partial(use_model_func, _priority=5) - result = await use_model_func(kw_prompt, keyword_extraction=True) + result = await use_model_func( + kw_prompt, keyword_extraction=True, token_tracker=token_tracker + ) # 5. 
Parse out JSON from the LLM response result = remove_think_tags(result) @@ -2746,7 +2761,10 @@ async def _get_vector_context( cosine_threshold = chunks_vdb.cosine_better_than_threshold results = await chunks_vdb.query( - query, top_k=search_top_k, query_embedding=query_embedding, metadata_filter=query_param.metadata_filter + query, + top_k=search_top_k, + query_embedding=query_embedding, + metadata_filter=query_param.metadata_filter, ) if not results: logger.info( @@ -2763,7 +2781,7 @@ async def _get_vector_context( "file_path": result.get("file_path", "unknown_source"), "source_type": "vector", # Mark the source type "chunk_id": result.get("id"), # Add chunk_id for deduplication - "metadata": result.get("metadata") + "metadata": result.get("metadata"), } valid_chunks.append(chunk_with_metadata) @@ -3529,8 +3547,9 @@ async def _get_node_data( f"Query nodes: {query} (top_k:{query_param.top_k}, cosine:{entities_vdb.cosine_better_than_threshold})" ) - - results = await entities_vdb.query(query, top_k=query_param.top_k, metadata_filter=query_param.metadata_filter) + results = await entities_vdb.query( + query, top_k=query_param.top_k, metadata_filter=query_param.metadata_filter + ) if not len(results): return [], [] @@ -3538,7 +3557,6 @@ async def _get_node_data( # Extract all entity IDs from your results list node_ids = [r["entity_name"] for r in results] - # Extract all entity IDs from your results list node_ids = [r["entity_name"] for r in results] @@ -3810,7 +3828,9 @@ async def _get_edge_data( f"Query edges: {keywords} (top_k:{query_param.top_k}, cosine:{relationships_vdb.cosine_better_than_threshold})" ) - results = await relationships_vdb.query(keywords, top_k=query_param.top_k, metadata_filter=query_param.metadata_filter) + results = await relationships_vdb.query( + keywords, top_k=query_param.top_k, metadata_filter=query_param.metadata_filter + ) if not len(results): return [], [] @@ -4104,6 +4124,7 @@ async def naive_query( hashing_kv: BaseKVStorage | None = 
None, system_prompt: str | None = None, return_raw_data: Literal[True] = True, + token_tracker: TokenTracker | None = None, ) -> dict[str, Any]: ... @@ -4116,6 +4137,7 @@ async def naive_query( hashing_kv: BaseKVStorage | None = None, system_prompt: str | None = None, return_raw_data: Literal[False] = False, + token_tracker: TokenTracker | None = None, ) -> str | AsyncIterator[str]: ... @@ -4126,6 +4148,7 @@ async def naive_query( global_config: dict[str, str], hashing_kv: BaseKVStorage | None = None, system_prompt: str | None = None, + token_tracker: TokenTracker | None = None, ) -> QueryResult: """ Execute naive query and return unified QueryResult object. @@ -4321,6 +4344,7 @@ async def naive_query( history_messages=query_param.conversation_history, enable_cot=True, stream=query_param.stream, + token_tracker=token_tracker, ) if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"): diff --git a/lightrag/utils.py b/lightrag/utils.py index 60542e43..0a459b34 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -4,6 +4,7 @@ import weakref import asyncio import html import csv +import contextvars import json import logging import logging.handlers @@ -507,6 +508,7 @@ def priority_limit_async_func_call( task_id, args, kwargs, + ctx, ) = await asyncio.wait_for(queue.get(), timeout=1.0) except asyncio.TimeoutError: continue @@ -536,11 +538,15 @@ def priority_limit_async_func_call( try: # Execute function with timeout protection if max_execution_timeout is not None: + # Run the function in the captured context + task = ctx.run(lambda: asyncio.create_task(func(*args, **kwargs))) result = await asyncio.wait_for( - func(*args, **kwargs), timeout=max_execution_timeout + task, timeout=max_execution_timeout ) else: - result = await func(*args, **kwargs) + # Run the function in the captured context + task = ctx.run(lambda: asyncio.create_task(func(*args, **kwargs))) + result = await task # Set result if future is still valid if not task_state.future.done(): @@ 
-791,6 +797,9 @@ def priority_limit_async_func_call( future=future, start_time=asyncio.get_event_loop().time() ) + # Capture current context + ctx = contextvars.copy_context() + try: # Register task state async with task_states_lock: @@ -809,13 +818,13 @@ def priority_limit_async_func_call( if _queue_timeout is not None: await asyncio.wait_for( queue.put( - (_priority, current_count, task_id, args, kwargs) + (_priority, current_count, task_id, args, kwargs, ctx) ), timeout=_queue_timeout, ) else: await queue.put( - (_priority, current_count, task_id, args, kwargs) + (_priority, current_count, task_id, args, kwargs, ctx) ) except asyncio.TimeoutError: raise QueueFullError( @@ -1472,8 +1481,7 @@ async def aexport_data( else: raise ValueError( - f"Unsupported file format: {file_format}. " - f"Choose from: csv, excel, md, txt" + f"Unsupported file format: {file_format}. Choose from: csv, excel, md, txt" ) if file_format is not None: print(f"Data exported to: {output_path} with format: {file_format}") @@ -1601,6 +1609,8 @@ async def use_llm_func_with_cache( cache_type: str = "extract", chunk_id: str | None = None, cache_keys_collector: list = None, + hashing_kv: "BaseKVStorage | None" = None, + token_tracker=None, ) -> tuple[str, int]: """Call LLM function with cache support and text sanitization @@ -1685,7 +1695,10 @@ async def use_llm_func_with_cache( kwargs["max_tokens"] = max_tokens res: str = await use_llm_func( - safe_user_prompt, system_prompt=safe_system_prompt, **kwargs + safe_user_prompt, + system_prompt=safe_system_prompt, + token_tracker=token_tracker, + **kwargs, ) res = remove_think_tags(res) @@ -1720,7 +1733,10 @@ async def use_llm_func_with_cache( try: res = await use_llm_func( - safe_user_prompt, system_prompt=safe_system_prompt, **kwargs + safe_user_prompt, + system_prompt=safe_system_prompt, + token_tracker=token_tracker, + **kwargs, ) except Exception as e: # Add [LLM func] prefix to error message @@ -2216,52 +2232,74 @@ async def 
pick_by_vector_similarity( return all_chunk_ids[:num_of_chunks] +from contextvars import ContextVar + + class TokenTracker: - """Track token usage for LLM calls.""" + """Track token usage for LLM calls using ContextVars for concurrency support.""" + + _usage_var: ContextVar[dict] = ContextVar("token_usage", default=None) def __init__(self): - self.reset() + # No instance state needed as we use ContextVar + pass def __enter__(self): self.reset() return self def __exit__(self, exc_type, exc_val, exc_tb): - print(self) + # Optional: Log usage on exit if needed + pass def reset(self): - self.prompt_tokens = 0 - self.completion_tokens = 0 - self.total_tokens = 0 - self.call_count = 0 + """Initialize/Reset token usage for the current context.""" + self._usage_var.set( + { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0, + "call_count": 0, + } + ) - def add_usage(self, token_counts): + def _get_current_usage(self) -> dict: + """Get the usage dict for the current context, initializing if necessary.""" + usage = self._usage_var.get() + if usage is None: + usage = { + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0, + "call_count": 0, + } + self._usage_var.set(usage) + return usage + + def add_usage(self, token_counts: dict): """Add token usage from one LLM call. 
Args: token_counts: A dictionary containing prompt_tokens, completion_tokens, total_tokens """ - self.prompt_tokens += token_counts.get("prompt_tokens", 0) - self.completion_tokens += token_counts.get("completion_tokens", 0) + usage = self._get_current_usage() + + usage["prompt_tokens"] += token_counts.get("prompt_tokens", 0) + usage["completion_tokens"] += token_counts.get("completion_tokens", 0) # If total_tokens is provided, use it directly; otherwise calculate the sum if "total_tokens" in token_counts: - self.total_tokens += token_counts["total_tokens"] + usage["total_tokens"] += token_counts["total_tokens"] else: - self.total_tokens += token_counts.get( + usage["total_tokens"] += token_counts.get( "prompt_tokens", 0 ) + token_counts.get("completion_tokens", 0) - self.call_count += 1 + usage["call_count"] += 1 def get_usage(self): """Get current usage statistics.""" - return { - "prompt_tokens": self.prompt_tokens, - "completion_tokens": self.completion_tokens, - "total_tokens": self.total_tokens, - "call_count": self.call_count, - } + return self._get_current_usage().copy() def __str__(self): usage = self.get_usage() @@ -2273,6 +2311,26 @@ class TokenTracker: ) +def estimate_embedding_tokens(texts: list[str], tokenizer: Tokenizer) -> int: + """Estimate tokens for embedding operations based on text length. + + Most embedding APIs don't return token counts, so we estimate based on + the tokenizer encoding. This provides a reasonable approximation for tracking. + + Args: + texts: List of text strings to be embedded + tokenizer: Tokenizer instance for encoding + + Returns: + Estimated total token count for all texts + """ + total = 0 + for text in texts: + if text: # Skip empty strings + total += len(tokenizer.encode(text)) + return total + + async def apply_rerank_if_enabled( query: str, retrieved_docs: list[dict],