feat (token_tracking): added token tracking to both query and insert endpoints -- and consequently to the pipeline

This commit is contained in:
GGrassia 2025-11-26 17:00:04 +01:00
parent cd664de057
commit 3a2d3ddb9f
11 changed files with 588 additions and 104 deletions

View file

@ -397,6 +397,15 @@ def parse_args() -> argparse.Namespace:
"EMBEDDING_BATCH_NUM", DEFAULT_EMBEDDING_BATCH_NUM, int "EMBEDDING_BATCH_NUM", DEFAULT_EMBEDDING_BATCH_NUM, int
) )
# Token tracking configuration
parser.add_argument(
"--enable-token-tracking",
action="store_true",
default=get_env_value("ENABLE_TOKEN_TRACKING", False, bool),
help="Enable token usage tracking for LLM calls (default: from env or False)",
)
args.enable_token_tracking = get_env_value("ENABLE_TOKEN_TRACKING", False, bool)
ollama_server_infos.LIGHTRAG_NAME = args.simulated_model_name ollama_server_infos.LIGHTRAG_NAME = args.simulated_model_name
ollama_server_infos.LIGHTRAG_TAG = args.simulated_model_tag ollama_server_infos.LIGHTRAG_TAG = args.simulated_model_tag

View file

@ -301,7 +301,10 @@ def create_app(args):
Path(args.working_dir).mkdir(parents=True, exist_ok=True) Path(args.working_dir).mkdir(parents=True, exist_ok=True)
def create_optimized_openai_llm_func( def create_optimized_openai_llm_func(
config_cache: LLMConfigCache, args, llm_timeout: int config_cache: LLMConfigCache,
args,
llm_timeout: int,
enable_token_tracking=False,
): ):
"""Create optimized OpenAI LLM function with pre-processed configuration""" """Create optimized OpenAI LLM function with pre-processed configuration"""
@ -332,13 +335,19 @@ def create_app(args):
history_messages=history_messages, history_messages=history_messages,
base_url=args.llm_binding_host, base_url=args.llm_binding_host,
api_key=args.llm_binding_api_key, api_key=args.llm_binding_api_key,
token_tracker=getattr(app.state, "token_tracker", None)
if enable_token_tracking and hasattr(app.state, "token_tracker")
else None,
**kwargs, **kwargs,
) )
return optimized_openai_alike_model_complete return optimized_openai_alike_model_complete
def create_optimized_azure_openai_llm_func( def create_optimized_azure_openai_llm_func(
config_cache: LLMConfigCache, args, llm_timeout: int config_cache: LLMConfigCache,
args,
llm_timeout: int,
enable_token_tracking=False,
): ):
"""Create optimized Azure OpenAI LLM function with pre-processed configuration""" """Create optimized Azure OpenAI LLM function with pre-processed configuration"""
@ -359,8 +368,8 @@ def create_app(args):
# Use pre-processed configuration to avoid repeated parsing # Use pre-processed configuration to avoid repeated parsing
kwargs["timeout"] = llm_timeout kwargs["timeout"] = llm_timeout
if config_cache.openai_llm_options: if config_cache.azure_openai_llm_options:
kwargs.update(config_cache.openai_llm_options) kwargs.update(config_cache.azure_openai_llm_options)
return await azure_openai_complete_if_cache( return await azure_openai_complete_if_cache(
args.llm_model, args.llm_model,
@ -370,12 +379,15 @@ def create_app(args):
base_url=args.llm_binding_host, base_url=args.llm_binding_host,
api_key=os.getenv("AZURE_OPENAI_API_KEY", args.llm_binding_api_key), api_key=os.getenv("AZURE_OPENAI_API_KEY", args.llm_binding_api_key),
api_version=os.getenv("AZURE_OPENAI_API_VERSION", "2024-08-01-preview"), api_version=os.getenv("AZURE_OPENAI_API_VERSION", "2024-08-01-preview"),
token_tracker=getattr(app.state, "token_tracker", None)
if enable_token_tracking and hasattr(app.state, "token_tracker")
else None,
**kwargs, **kwargs,
) )
return optimized_azure_openai_model_complete return optimized_azure_openai_model_complete
def create_llm_model_func(binding: str): def create_llm_model_func(binding: str, enable_token_tracking=False):
""" """
Create LLM model function based on binding type. Create LLM model function based on binding type.
Uses optimized functions for OpenAI bindings and lazy import for others. Uses optimized functions for OpenAI bindings and lazy import for others.
@ -384,21 +396,42 @@ def create_app(args):
if binding == "lollms": if binding == "lollms":
from lightrag.llm.lollms import lollms_model_complete from lightrag.llm.lollms import lollms_model_complete
return lollms_model_complete async def lollms_model_complete_with_tracker(*args, **kwargs):
# Add token tracker if enabled
if enable_token_tracking and hasattr(app.state, "token_tracker"):
kwargs["token_tracker"] = app.state.token_tracker
return await lollms_model_complete(*args, **kwargs)
return lollms_model_complete_with_tracker
elif binding == "ollama": elif binding == "ollama":
from lightrag.llm.ollama import ollama_model_complete from lightrag.llm.ollama import ollama_model_complete
return ollama_model_complete async def ollama_model_complete_with_tracker(*args, **kwargs):
# Add token tracker if enabled
if enable_token_tracking and hasattr(app.state, "token_tracker"):
kwargs["token_tracker"] = app.state.token_tracker
return await ollama_model_complete(*args, **kwargs)
return ollama_model_complete_with_tracker
elif binding == "aws_bedrock": elif binding == "aws_bedrock":
return bedrock_model_complete # Already defined locally
async def bedrock_model_complete_with_tracker(*args, **kwargs):
# Add token tracker if enabled
if enable_token_tracking and hasattr(app.state, "token_tracker"):
kwargs["token_tracker"] = app.state.token_tracker
return await bedrock_model_complete(*args, **kwargs)
return bedrock_model_complete_with_tracker
elif binding == "azure_openai": elif binding == "azure_openai":
# Use optimized function with pre-processed configuration # Use optimized function with pre-processed configuration
return create_optimized_azure_openai_llm_func( return create_optimized_azure_openai_llm_func(
config_cache, args, llm_timeout config_cache, args, llm_timeout, enable_token_tracking
) )
else: # openai and compatible else: # openai and compatible
# Use optimized function with pre-processed configuration # Use optimized function with pre-processed configuration
return create_optimized_openai_llm_func(config_cache, args, llm_timeout) return create_optimized_openai_llm_func(
config_cache, args, llm_timeout, enable_token_tracking
)
except ImportError as e: except ImportError as e:
raise Exception(f"Failed to import {binding} LLM binding: {e}") raise Exception(f"Failed to import {binding} LLM binding: {e}")
@ -422,7 +455,14 @@ def create_app(args):
return {} return {}
def create_optimized_embedding_function( def create_optimized_embedding_function(
config_cache: LLMConfigCache, binding, model, host, api_key, dimensions, args config_cache: LLMConfigCache,
binding,
model,
host,
api_key,
dimensions,
args,
enable_token_tracking=False,
): ):
""" """
Create optimized embedding function with pre-processed configuration for applicable bindings. Create optimized embedding function with pre-processed configuration for applicable bindings.
@ -435,7 +475,13 @@ def create_app(args):
from lightrag.llm.lollms import lollms_embed from lightrag.llm.lollms import lollms_embed
return await lollms_embed( return await lollms_embed(
texts, embed_model=model, host=host, api_key=api_key texts,
embed_model=model,
host=host,
api_key=api_key,
token_tracker=getattr(app.state, "token_tracker", None)
if enable_token_tracking and hasattr(app.state, "token_tracker")
else None,
) )
elif binding == "ollama": elif binding == "ollama":
from lightrag.llm.ollama import ollama_embed from lightrag.llm.ollama import ollama_embed
@ -455,26 +501,54 @@ def create_app(args):
host=host, host=host,
api_key=api_key, api_key=api_key,
options=ollama_options, options=ollama_options,
token_tracker=getattr(app.state, "token_tracker", None)
if enable_token_tracking and hasattr(app.state, "token_tracker")
else None,
) )
elif binding == "azure_openai": elif binding == "azure_openai":
from lightrag.llm.azure_openai import azure_openai_embed from lightrag.llm.azure_openai import azure_openai_embed
return await azure_openai_embed(texts, model=model, api_key=api_key) return await azure_openai_embed(
texts,
model=model,
api_key=api_key,
token_tracker=getattr(app.state, "token_tracker", None)
if enable_token_tracking and hasattr(app.state, "token_tracker")
else None,
)
elif binding == "aws_bedrock": elif binding == "aws_bedrock":
from lightrag.llm.bedrock import bedrock_embed from lightrag.llm.bedrock import bedrock_embed
return await bedrock_embed(texts, model=model) return await bedrock_embed(
texts,
model=model,
token_tracker=getattr(app.state, "token_tracker", None)
if enable_token_tracking and hasattr(app.state, "token_tracker")
else None,
)
elif binding == "jina": elif binding == "jina":
from lightrag.llm.jina import jina_embed from lightrag.llm.jina import jina_embed
return await jina_embed( return await jina_embed(
texts, dimensions=dimensions, base_url=host, api_key=api_key texts,
dimensions=dimensions,
base_url=host,
api_key=api_key,
token_tracker=getattr(app.state, "token_tracker", None)
if enable_token_tracking and hasattr(app.state, "token_tracker")
else None,
) )
else: # openai and compatible else: # openai and compatible
from lightrag.llm.openai import openai_embed from lightrag.llm.openai import openai_embed
return await openai_embed( return await openai_embed(
texts, model=model, base_url=host, api_key=api_key texts,
model=model,
base_url=host,
api_key=api_key,
token_tracker=getattr(app.state, "token_tracker", None)
if enable_token_tracking and hasattr(app.state, "token_tracker")
else None,
) )
except ImportError as e: except ImportError as e:
raise Exception(f"Failed to import {binding} embedding: {e}") raise Exception(f"Failed to import {binding} embedding: {e}")
@ -594,7 +668,9 @@ def create_app(args):
rag = LightRAG( rag = LightRAG(
working_dir=args.working_dir, working_dir=args.working_dir,
workspace=args.workspace, workspace=args.workspace,
llm_model_func=create_llm_model_func(args.llm_binding), llm_model_func=create_llm_model_func(
args.llm_binding, args.enable_token_tracking
),
llm_model_name=args.llm_model, llm_model_name=args.llm_model,
llm_model_max_async=args.max_async, llm_model_max_async=args.max_async,
summary_max_tokens=args.summary_max_tokens, summary_max_tokens=args.summary_max_tokens,
@ -604,7 +680,16 @@ def create_app(args):
llm_model_kwargs=create_llm_model_kwargs( llm_model_kwargs=create_llm_model_kwargs(
args.llm_binding, args, llm_timeout args.llm_binding, args, llm_timeout
), ),
embedding_func=embedding_func, embedding_func=create_optimized_embedding_function(
config_cache,
args.embedding_binding,
args.embedding_model,
args.embedding_binding_host,
args.embedding_binding_api_key,
args.embedding_dim,
args,
args.enable_token_tracking,
),
default_llm_timeout=llm_timeout, default_llm_timeout=llm_timeout,
default_embedding_timeout=embedding_timeout, default_embedding_timeout=embedding_timeout,
kv_storage=args.kv_storage, kv_storage=args.kv_storage,
@ -629,6 +714,17 @@ def create_app(args):
logger.error(f"Failed to initialize LightRAG: {e}") logger.error(f"Failed to initialize LightRAG: {e}")
raise raise
# Initialize token tracking if enabled
token_tracker = None
if args.enable_token_tracking:
from lightrag.utils import TokenTracker
token_tracker = TokenTracker()
logger.info("Token tracking enabled")
# Add token tracker to the app state for use in endpoints
app.state.token_tracker = token_tracker
# Add routes # Add routes
app.include_router( app.include_router(
create_document_routes( create_document_routes(
@ -637,7 +733,7 @@ def create_app(args):
api_key, api_key,
) )
) )
app.include_router(create_query_routes(rag, api_key, args.top_k)) app.include_router(create_query_routes(rag, api_key, args.top_k, token_tracker))
app.include_router(create_graph_routes(rag, api_key)) app.include_router(create_graph_routes(rag, api_key))
# Add Ollama API routes # Add Ollama API routes

View file

@ -243,6 +243,7 @@ class InsertResponse(BaseModel):
status: Status of the operation (success, duplicated, partial_success, failure) status: Status of the operation (success, duplicated, partial_success, failure)
message: Detailed message describing the operation result message: Detailed message describing the operation result
track_id: Tracking ID for monitoring processing status track_id: Tracking ID for monitoring processing status
token_usage: Token usage statistics for the insertion (prompt_tokens, completion_tokens, total_tokens, call_count)
""" """
status: Literal["success", "duplicated", "partial_success", "failure"] = Field( status: Literal["success", "duplicated", "partial_success", "failure"] = Field(
@ -250,6 +251,10 @@ class InsertResponse(BaseModel):
) )
message: str = Field(description="Message describing the operation result") message: str = Field(description="Message describing the operation result")
track_id: str = Field(description="Tracking ID for monitoring processing status") track_id: str = Field(description="Tracking ID for monitoring processing status")
token_usage: Optional[Dict[str, int]] = Field(
default=None,
description="Token usage statistics for the insertion (prompt_tokens, completion_tokens, total_tokens, call_count)",
)
class Config: class Config:
json_schema_extra = { json_schema_extra = {
@ -257,10 +262,17 @@ class InsertResponse(BaseModel):
"status": "success", "status": "success",
"message": "File 'document.pdf' uploaded successfully. Processing will continue in background.", "message": "File 'document.pdf' uploaded successfully. Processing will continue in background.",
"track_id": "upload_20250729_170612_abc123", "track_id": "upload_20250729_170612_abc123",
"token_usage": {
"prompt_tokens": 1250,
"completion_tokens": 450,
"total_tokens": 1700,
"call_count": 3,
},
} }
} }
class ClearDocumentsResponse(BaseModel): class ClearDocumentsResponse(BaseModel):
"""Response model for document clearing operation """Response model for document clearing operation

View file

@ -6,7 +6,7 @@ import json
import logging import logging
from typing import Any, Dict, List, Literal, Optional from typing import Any, Dict, List, Literal, Optional
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException, Request, Response
from lightrag.base import QueryParam from lightrag.base import QueryParam
from lightrag.types import MetadataFilter from lightrag.types import MetadataFilter
from lightrag.api.utils_api import get_combined_auth_dependency from lightrag.api.utils_api import get_combined_auth_dependency
@ -157,6 +157,10 @@ class QueryResponse(BaseModel):
default=None, default=None,
description="Reference list (Disabled when include_references=False, /query/data always includes references.)", description="Reference list (Disabled when include_references=False, /query/data always includes references.)",
) )
token_usage: Optional[Dict[str, int]] = Field(
default=None,
description="Token usage statistics for the query (prompt_tokens, completion_tokens, total_tokens)",
)
class QueryDataResponse(BaseModel): class QueryDataResponse(BaseModel):
@ -168,6 +172,10 @@ class QueryDataResponse(BaseModel):
metadata: Dict[str, Any] = Field( metadata: Dict[str, Any] = Field(
description="Query metadata including mode, keywords, and processing information" description="Query metadata including mode, keywords, and processing information"
) )
token_usage: Optional[Dict[str, int]] = Field(
default=None,
description="Token usage statistics for the query (prompt_tokens, completion_tokens, total_tokens)",
)
class StreamChunkResponse(BaseModel): class StreamChunkResponse(BaseModel):
@ -183,9 +191,18 @@ class StreamChunkResponse(BaseModel):
error: Optional[str] = Field( error: Optional[str] = Field(
default=None, description="Error message if processing fails" default=None, description="Error message if processing fails"
) )
token_usage: Optional[Dict[str, int]] = Field(
default=None,
description="Token usage statistics for the entire query (only in final chunk)",
)
def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60): def create_query_routes(
rag,
api_key: Optional[str] = None,
top_k: int = 60,
token_tracker: Optional[Any] = None,
):
combined_auth = get_combined_auth_dependency(api_key) combined_auth = get_combined_auth_dependency(api_key)
@router.post( @router.post(
@ -220,7 +237,7 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
}, },
"examples": { "examples": {
"with_references": { "with_references": {
"summary": "Response with references", "summary": "Response with references and token usage",
"description": "Example response when include_references=True", "description": "Example response when include_references=True",
"value": { "value": {
"response": "Artificial Intelligence (AI) is a branch of computer science that aims to create intelligent machines capable of performing tasks that typically require human intelligence, such as learning, reasoning, and problem-solving.", "response": "Artificial Intelligence (AI) is a branch of computer science that aims to create intelligent machines capable of performing tasks that typically require human intelligence, such as learning, reasoning, and problem-solving.",
@ -234,13 +251,25 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
"file_path": "/documents/machine_learning.txt", "file_path": "/documents/machine_learning.txt",
}, },
], ],
"token_usage": {
"prompt_tokens": 245,
"completion_tokens": 87,
"total_tokens": 332,
"call_count": 1,
},
}, },
}, },
"without_references": { "without_references": {
"summary": "Response without references", "summary": "Response without references but with token usage",
"description": "Example response when include_references=False", "description": "Example response when include_references=False",
"value": { "value": {
"response": "Artificial Intelligence (AI) is a branch of computer science that aims to create intelligent machines capable of performing tasks that typically require human intelligence, such as learning, reasoning, and problem-solving." "response": "Artificial Intelligence (AI) is a branch of computer science that aims to create intelligent machines capable of performing tasks that typically require human intelligence, such as learning, reasoning, and problem-solving.",
"token_usage": {
"prompt_tokens": 245,
"completion_tokens": 87,
"total_tokens": 332,
"call_count": 1,
},
}, },
}, },
"different_modes": { "different_modes": {
@ -358,6 +387,10 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
- 500: Internal processing error (e.g., LLM service unavailable) - 500: Internal processing error (e.g., LLM service unavailable)
""" """
try: try:
# Reset token tracker at start of query if available
if token_tracker:
token_tracker.reset()
param = request.to_query_params( param = request.to_query_params(
False False
) # Ensure stream=False for non-streaming endpoint ) # Ensure stream=False for non-streaming endpoint
@ -376,11 +409,22 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
if not response_content: if not response_content:
response_content = "No relevant context found for the query." response_content = "No relevant context found for the query."
# Get token usage if available
token_usage = None
if token_tracker:
token_usage = token_tracker.get_usage()
# Return response with or without references based on request # Return response with or without references based on request
if request.include_references: if request.include_references:
return QueryResponse(response=response_content, references=references) return QueryResponse(
response=response_content,
references=references,
token_usage=token_usage,
)
else: else:
return QueryResponse(response=response_content, references=None) return QueryResponse(
response=response_content, references=None, token_usage=token_usage
)
except Exception as e: except Exception as e:
trace_exception(e) trace_exception(e)
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@ -577,6 +621,10 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
Use streaming mode for real-time interfaces and non-streaming for batch processing. Use streaming mode for real-time interfaces and non-streaming for batch processing.
""" """
try: try:
# Reset token tracker at start of query if available
if token_tracker:
token_tracker.reset()
# Use the stream parameter from the request, defaulting to True if not specified # Use the stream parameter from the request, defaulting to True if not specified
stream_mode = request.stream if request.stream is not None else True stream_mode = request.stream if request.stream is not None else True
param = request.to_query_params(stream_mode) param = request.to_query_params(stream_mode)
@ -605,6 +653,10 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
except Exception as e: except Exception as e:
logging.error(f"Streaming error: {str(e)}") logging.error(f"Streaming error: {str(e)}")
yield f"{json.dumps({'error': str(e)})}\n" yield f"{json.dumps({'error': str(e)})}\n"
# Add final token usage chunk if streaming and token tracker is available
if token_tracker and llm_response.get("is_streaming"):
yield f"{json.dumps({'token_usage': token_tracker.get_usage()})}\n"
else: else:
# Non-streaming mode: send complete response in one message # Non-streaming mode: send complete response in one message
response_content = llm_response.get("content", "") response_content = llm_response.get("content", "")
@ -616,6 +668,10 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
if request.include_references: if request.include_references:
complete_response["references"] = references complete_response["references"] = references
# Add token usage if available
if token_tracker:
complete_response["token_usage"] = token_tracker.get_usage()
yield f"{json.dumps(complete_response)}\n" yield f"{json.dumps(complete_response)}\n"
return StreamingResponse( return StreamingResponse(
@ -1022,18 +1078,31 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
as structured data analysis typically requires source attribution. as structured data analysis typically requires source attribution.
""" """
try: try:
# Reset token tracker at start of query if available
if token_tracker:
token_tracker.reset()
param = request.to_query_params(False) # No streaming for data endpoint param = request.to_query_params(False) # No streaming for data endpoint
response = await rag.aquery_data(request.query, param=param) response = await rag.aquery_data(request.query, param=param)
# Get token usage if available
token_usage = None
if token_tracker:
token_usage = token_tracker.get_usage()
# aquery_data returns the new format with status, message, data, and metadata # aquery_data returns the new format with status, message, data, and metadata
if isinstance(response, dict): if isinstance(response, dict):
return QueryDataResponse(**response) response_dict = dict(response)
response_dict["token_usage"] = token_usage
return QueryDataResponse(**response_dict)
else: else:
# Handle unexpected response format # Handle unexpected response format
return QueryDataResponse( return QueryDataResponse(
status="failure", status="failure",
message="Invalid response type", message="Unexpected response format",
data={}, data={},
metadata={},
token_usage=None,
) )
except Exception as e: except Exception as e:
trace_exception(e) trace_exception(e)

View file

@ -870,7 +870,8 @@ class LightRAG:
ids: str | list[str] | None = None, ids: str | list[str] | None = None,
file_paths: str | list[str] | None = None, file_paths: str | list[str] | None = None,
track_id: str | None = None, track_id: str | None = None,
) -> str: token_tracker: TokenTracker | None = None,
) -> tuple[str, dict]:
"""Sync Insert documents with checkpoint support """Sync Insert documents with checkpoint support
Args: Args:
@ -884,10 +885,10 @@ class LightRAG:
track_id: tracking ID for monitoring processing status, if not provided, will be generated track_id: tracking ID for monitoring processing status, if not provided, will be generated
Returns: Returns:
str: tracking ID for monitoring processing status tuple[str, dict]: (tracking ID for monitoring processing status, token usage statistics)
""" """
loop = always_get_an_event_loop() loop = always_get_an_event_loop()
return loop.run_until_complete( result = loop.run_until_complete(
self.ainsert( self.ainsert(
input, input,
split_by_character, split_by_character,
@ -895,8 +896,10 @@ class LightRAG:
ids, ids,
file_paths, file_paths,
track_id, track_id,
token_tracker,
) )
) )
return result
async def ainsert( async def ainsert(
self, self,
@ -906,7 +909,8 @@ class LightRAG:
ids: str | list[str] | None = None, ids: str | list[str] | None = None,
file_paths: str | list[str] | None = None, file_paths: str | list[str] | None = None,
track_id: str | None = None, track_id: str | None = None,
) -> str: token_tracker: TokenTracker | None = None,
) -> tuple[str, dict]:
"""Async Insert documents with checkpoint support """Async Insert documents with checkpoint support
Args: Args:
@ -930,9 +934,15 @@ class LightRAG:
await self.apipeline_process_enqueue_documents( await self.apipeline_process_enqueue_documents(
split_by_character, split_by_character,
split_by_character_only, split_by_character_only,
token_tracker,
) )
return track_id return track_id, token_tracker.get_usage() if token_tracker else {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0,
"call_count": 0,
}
# TODO: deprecated, use insert instead # TODO: deprecated, use insert instead
def insert_custom_chunks( def insert_custom_chunks(
@ -1107,7 +1117,9 @@ class LightRAG:
"file_path" "file_path"
], # Store file path in document status ], # Store file path in document status
"track_id": track_id, # Store track_id in document status "track_id": track_id, # Store track_id in document status
"metadata": metadata[i] if isinstance(metadata, list) and i < len(metadata) else metadata, # added provided custom metadata per document "metadata": metadata[i]
if isinstance(metadata, list) and i < len(metadata)
else metadata, # added provided custom metadata per document
} }
for i, (id_, content_data) in enumerate(contents.items()) for i, (id_, content_data) in enumerate(contents.items())
} }
@ -1363,6 +1375,7 @@ class LightRAG:
self, self,
split_by_character: str | None = None, split_by_character: str | None = None,
split_by_character_only: bool = False, split_by_character_only: bool = False,
token_tracker: TokenTracker | None = None,
) -> None: ) -> None:
""" """
Process pending documents by splitting them into chunks, processing Process pending documents by splitting them into chunks, processing
@ -1484,9 +1497,12 @@ class LightRAG:
pipeline_status: dict, pipeline_status: dict,
pipeline_status_lock: asyncio.Lock, pipeline_status_lock: asyncio.Lock,
semaphore: asyncio.Semaphore, semaphore: asyncio.Semaphore,
token_tracker: TokenTracker | None = None,
) -> None: ) -> None:
"""Process single document""" """Process single document"""
doc_metadata = getattr(status_doc, "metadata", None) doc_metadata = getattr(status_doc, "metadata", None)
if doc_metadata is None:
doc_metadata = {}
file_extraction_stage_ok = False file_extraction_stage_ok = False
async with semaphore: async with semaphore:
nonlocal processed_count nonlocal processed_count
@ -1498,7 +1514,6 @@ class LightRAG:
# Get file path from status document # Get file path from status document
file_path = getattr( file_path = getattr(
status_doc, "file_path", "unknown_source" status_doc, "file_path", "unknown_source"
) )
async with pipeline_status_lock: async with pipeline_status_lock:
@ -1561,7 +1576,9 @@ class LightRAG:
# Process document in two stages # Process document in two stages
# Stage 1: Process text chunks and docs (parallel execution) # Stage 1: Process text chunks and docs (parallel execution)
doc_metadata["processing_start_time"] = processing_start_time doc_metadata["processing_start_time"] = (
processing_start_time
)
doc_status_task = asyncio.create_task( doc_status_task = asyncio.create_task(
self.doc_status.upsert( self.doc_status.upsert(
@ -1610,6 +1627,7 @@ class LightRAG:
doc_metadata, doc_metadata,
pipeline_status, pipeline_status,
pipeline_status_lock, pipeline_status_lock,
token_tracker,
) )
) )
await entity_relation_task await entity_relation_task
@ -1643,7 +1661,9 @@ class LightRAG:
processing_end_time = int(time.time()) processing_end_time = int(time.time())
# Update document status to failed # Update document status to failed
doc_metadata["processing_start_time"] = processing_start_time doc_metadata["processing_start_time"] = (
processing_start_time
)
doc_metadata["processing_end_time"] = processing_end_time doc_metadata["processing_end_time"] = processing_end_time
await self.doc_status.upsert( await self.doc_status.upsert(
{ {
@ -1691,7 +1711,9 @@ class LightRAG:
doc_metadata["processing_start_time"] = ( doc_metadata["processing_start_time"] = (
processing_start_time processing_start_time
) )
doc_metadata["processing_end_time"] = processing_end_time doc_metadata["processing_end_time"] = (
processing_end_time
)
await self.doc_status.upsert( await self.doc_status.upsert(
{ {
@ -1748,7 +1770,9 @@ class LightRAG:
doc_metadata["processing_start_time"] = ( doc_metadata["processing_start_time"] = (
processing_start_time processing_start_time
) )
doc_metadata["processing_end_time"] = processing_end_time doc_metadata["processing_end_time"] = (
processing_end_time
)
await self.doc_status.upsert( await self.doc_status.upsert(
{ {
@ -1778,6 +1802,7 @@ class LightRAG:
pipeline_status, pipeline_status,
pipeline_status_lock, pipeline_status_lock,
semaphore, semaphore,
token_tracker,
) )
) )
@ -1827,6 +1852,7 @@ class LightRAG:
metadata: dict | None, metadata: dict | None,
pipeline_status=None, pipeline_status=None,
pipeline_status_lock=None, pipeline_status_lock=None,
token_tracker: TokenTracker | None = None,
) -> list: ) -> list:
try: try:
chunk_results = await extract_entities( chunk_results = await extract_entities(
@ -1837,6 +1863,7 @@ class LightRAG:
pipeline_status_lock=pipeline_status_lock, pipeline_status_lock=pipeline_status_lock,
llm_response_cache=self.llm_response_cache, llm_response_cache=self.llm_response_cache,
text_chunks_storage=self.text_chunks, text_chunks_storage=self.text_chunks,
token_tracker=token_tracker,
) )
return chunk_results return chunk_results
except Exception as e: except Exception as e:
@ -2061,14 +2088,18 @@ class LightRAG:
self, self,
query: str, query: str,
param: QueryParam = QueryParam(), param: QueryParam = QueryParam(),
token_tracker: TokenTracker | None = None,
system_prompt: str | None = None, system_prompt: str | None = None,
) -> str | Iterator[str]: ) -> Any:
""" """
Perform a sync query. User query interface (backward compatibility wrapper).
Delegates to aquery() for asynchronous execution and returns the result.
Args: Args:
query (str): The query to be executed. query (str): The query to be executed.
param (QueryParam): Configuration parameters for query execution. param (QueryParam): Configuration parameters for query execution.
token_tracker (TokenTracker | None): Optional token tracker for monitoring usage.
prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"]. prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].
Returns: Returns:
@ -2076,14 +2107,17 @@ class LightRAG:
""" """
loop = always_get_an_event_loop() loop = always_get_an_event_loop()
return loop.run_until_complete(self.aquery(query, param, system_prompt)) # type: ignore return loop.run_until_complete(
self.aquery(query, param, token_tracker, system_prompt)
) # type: ignore
async def aquery( async def aquery(
self, self,
query: str, query: str,
param: QueryParam = QueryParam(), param: QueryParam = QueryParam(),
token_tracker: TokenTracker | None = None,
system_prompt: str | None = None, system_prompt: str | None = None,
) -> str | AsyncIterator[str]: ) -> Any:
""" """
Perform a async query (backward compatibility wrapper). Perform a async query (backward compatibility wrapper).
@ -2102,7 +2136,7 @@ class LightRAG:
- Streaming: Returns AsyncIterator[str] - Streaming: Returns AsyncIterator[str]
""" """
# Call the new aquery_llm function to get complete results # Call the new aquery_llm function to get complete results
result = await self.aquery_llm(query, param, system_prompt) result = await self.aquery_llm(query, param, system_prompt, token_tracker)
# Extract and return only the LLM response for backward compatibility # Extract and return only the LLM response for backward compatibility
llm_response = result.get("llm_response", {}) llm_response = result.get("llm_response", {})
@ -2137,6 +2171,7 @@ class LightRAG:
self, self,
query: str, query: str,
param: QueryParam = QueryParam(), param: QueryParam = QueryParam(),
token_tracker: TokenTracker | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
""" """
Asynchronous data retrieval API: returns structured retrieval results without LLM generation. Asynchronous data retrieval API: returns structured retrieval results without LLM generation.
@ -2330,6 +2365,7 @@ class LightRAG:
query: str, query: str,
param: QueryParam = QueryParam(), param: QueryParam = QueryParam(),
system_prompt: str | None = None, system_prompt: str | None = None,
token_tracker: TokenTracker | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
""" """
Asynchronous complete query API: returns structured retrieval results with LLM generation. Asynchronous complete query API: returns structured retrieval results with LLM generation.
@ -2364,6 +2400,7 @@ class LightRAG:
hashing_kv=self.llm_response_cache, hashing_kv=self.llm_response_cache,
system_prompt=system_prompt, system_prompt=system_prompt,
chunks_vdb=self.chunks_vdb, chunks_vdb=self.chunks_vdb,
token_tracker=token_tracker,
) )
elif param.mode == "naive": elif param.mode == "naive":
query_result = await naive_query( query_result = await naive_query(
@ -2373,6 +2410,7 @@ class LightRAG:
global_config, global_config,
hashing_kv=self.llm_response_cache, hashing_kv=self.llm_response_cache,
system_prompt=system_prompt, system_prompt=system_prompt,
token_tracker=token_tracker,
) )
elif param.mode == "bypass": elif param.mode == "bypass":
# Bypass mode: directly use LLM without knowledge retrieval # Bypass mode: directly use LLM without knowledge retrieval

View file

@ -46,6 +46,7 @@ async def azure_openai_complete_if_cache(
base_url: str | None = None, base_url: str | None = None,
api_key: str | None = None, api_key: str | None = None,
api_version: str | None = None, api_version: str | None = None,
token_tracker: Any | None = None,
**kwargs, **kwargs,
): ):
if enable_cot: if enable_cot:
@ -94,28 +95,73 @@ async def azure_openai_complete_if_cache(
) )
if hasattr(response, "__aiter__"): if hasattr(response, "__aiter__"):
final_chunk_usage = None
accumulated_response = ""
async def inner():
    """Yield streamed text pieces and report token usage when the stream ends.

    Reads ``response`` and ``token_tracker`` from the enclosing scope.

    Yields:
        str: Each non-``None`` ``delta.content`` piece, passed through
        ``safe_unicode_decode`` when it contains a literal backslash-u
        sequence.

    Raises:
        Exception: Any error raised while iterating the stream is logged
        and re-raised unchanged.
    """
    # Kept local to the generator: nothing outside it reads the usage record.
    stream_usage = None
    try:
        async for chunk in response:
            # Capture usage BEFORE the guards below. When usage reporting is
            # requested, the chunk that carries it has an empty ``choices``
            # list, so a check placed after the ``continue`` statements
            # would never see it.
            chunk_usage = getattr(chunk, "usage", None)
            if chunk_usage is not None:
                stream_usage = chunk_usage

            if len(chunk.choices) == 0:
                continue
            content = chunk.choices[0].delta.content
            if content is None:
                continue
            if r"\u" in content:
                content = safe_unicode_decode(content.encode("utf-8"))
            yield content
    except Exception as e:
        logger.error(f"Error in Azure OpenAI stream response: {str(e)}")
        raise
    finally:
        # Runs on normal completion, on error, and when the consumer closes
        # the generator early, so whatever usage was seen is still recorded.
        if token_tracker and stream_usage:
            # Use actual usage from the API
            token_counts = {
                "prompt_tokens": getattr(stream_usage, "prompt_tokens", 0),
                "completion_tokens": getattr(stream_usage, "completion_tokens", 0),
                "total_tokens": getattr(stream_usage, "total_tokens", 0),
            }
            token_tracker.add_usage(token_counts)
            logger.debug(f"Azure OpenAI streaming token usage: {token_counts}")
        elif token_tracker:
            logger.debug(
                "No usage information available in Azure OpenAI streaming response"
            )
return inner() return inner()
else: else:
content = response.choices[0].message.content content = response.choices[0].message.content
if r"\u" in content: if r"\u" in content:
content = safe_unicode_decode(content.encode("utf-8")) content = safe_unicode_decode(content.encode("utf-8"))
# Track token usage for non-streaming response
if token_tracker and hasattr(response, "usage"):
token_counts = {
"prompt_tokens": getattr(response.usage, "prompt_tokens", 0),
"completion_tokens": getattr(response.usage, "completion_tokens", 0),
"total_tokens": getattr(response.usage, "total_tokens", 0),
}
token_tracker.add_usage(token_counts)
logger.debug(f"Azure OpenAI non-streaming token usage: {token_counts}")
return content return content
async def azure_openai_complete( async def azure_openai_complete(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs prompt,
system_prompt=None,
history_messages=[],
keyword_extraction=False,
token_tracker=None,
**kwargs,
) -> str: ) -> str:
kwargs.pop("keyword_extraction", None) kwargs.pop("keyword_extraction", None)
result = await azure_openai_complete_if_cache( result = await azure_openai_complete_if_cache(
@ -123,6 +169,7 @@ async def azure_openai_complete(
prompt, prompt,
system_prompt=system_prompt, system_prompt=system_prompt,
history_messages=history_messages, history_messages=history_messages,
token_tracker=token_tracker,
**kwargs, **kwargs,
) )
return result return result
@ -142,6 +189,7 @@ async def azure_openai_embed(
base_url: str | None = None, base_url: str | None = None,
api_key: str | None = None, api_key: str | None = None,
api_version: str | None = None, api_version: str | None = None,
token_tracker: Any | None = None,
) -> np.ndarray: ) -> np.ndarray:
deployment = ( deployment = (
os.getenv("AZURE_EMBEDDING_DEPLOYMENT") os.getenv("AZURE_EMBEDDING_DEPLOYMENT")
@ -174,4 +222,14 @@ async def azure_openai_embed(
response = await openai_async_client.embeddings.create( response = await openai_async_client.embeddings.create(
model=model, input=texts, encoding_format="float" model=model, input=texts, encoding_format="float"
) )
# Track token usage for embeddings if token tracker is provided
if token_tracker and hasattr(response, "usage"):
token_counts = {
"prompt_tokens": getattr(response.usage, "prompt_tokens", 0),
"total_tokens": getattr(response.usage, "total_tokens", 0),
}
token_tracker.add_usage(token_counts)
logger.debug(f"Azure OpenAI embedding token usage: {token_counts}")
return np.array([dp.embedding for dp in response.data]) return np.array([dp.embedding for dp in response.data])

View file

@ -48,6 +48,7 @@ async def bedrock_complete_if_cache(
aws_access_key_id=None, aws_access_key_id=None,
aws_secret_access_key=None, aws_secret_access_key=None,
aws_session_token=None, aws_session_token=None,
token_tracker=None,
**kwargs, **kwargs,
) -> Union[str, AsyncIterator[str]]: ) -> Union[str, AsyncIterator[str]]:
if enable_cot: if enable_cot:
@ -155,6 +156,18 @@ async def bedrock_complete_if_cache(
yield text yield text
# Handle other event types that might indicate stream end # Handle other event types that might indicate stream end
elif "messageStop" in event: elif "messageStop" in event:
# Track token usage for streaming if token tracker is provided
if token_tracker and "usage" in event:
usage = event["usage"]
token_counts = {
"prompt_tokens": usage.get("inputTokens", 0),
"completion_tokens": usage.get("outputTokens", 0),
"total_tokens": usage.get("totalTokens", 0),
}
token_tracker.add_usage(token_counts)
logging.debug(
f"Bedrock streaming token usage: {token_counts}"
)
break break
except Exception as e: except Exception as e:
@ -228,6 +241,17 @@ async def bedrock_complete_if_cache(
if not content or content.strip() == "": if not content or content.strip() == "":
raise BedrockError("Received empty content from Bedrock API") raise BedrockError("Received empty content from Bedrock API")
# Track token usage for non-streaming if token tracker is provided
if token_tracker and "usage" in response:
usage = response["usage"]
token_counts = {
"prompt_tokens": usage.get("inputTokens", 0),
"completion_tokens": usage.get("outputTokens", 0),
"total_tokens": usage.get("totalTokens", 0),
}
token_tracker.add_usage(token_counts)
logging.debug(f"Bedrock non-streaming token usage: {token_counts}")
return content return content
except Exception as e: except Exception as e:
@ -239,7 +263,12 @@ async def bedrock_complete_if_cache(
# Generic Bedrock completion function # Generic Bedrock completion function
async def bedrock_complete( async def bedrock_complete(
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs prompt,
system_prompt=None,
history_messages=[],
keyword_extraction=False,
token_tracker=None,
**kwargs,
) -> Union[str, AsyncIterator[str]]: ) -> Union[str, AsyncIterator[str]]:
kwargs.pop("keyword_extraction", None) kwargs.pop("keyword_extraction", None)
model_name = kwargs["hashing_kv"].global_config["llm_model_name"] model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
@ -248,6 +277,7 @@ async def bedrock_complete(
prompt, prompt,
system_prompt=system_prompt, system_prompt=system_prompt,
history_messages=history_messages, history_messages=history_messages,
token_tracker=token_tracker,
**kwargs, **kwargs,
) )
return result return result
@ -265,6 +295,7 @@ async def bedrock_embed(
aws_access_key_id=None, aws_access_key_id=None,
aws_secret_access_key=None, aws_secret_access_key=None,
aws_session_token=None, aws_session_token=None,
token_tracker=None,
) -> np.ndarray: ) -> np.ndarray:
# Respect existing env; only set if a non-empty value is available # Respect existing env; only set if a non-empty value is available
access_key = os.environ.get("AWS_ACCESS_KEY_ID") or aws_access_key_id access_key = os.environ.get("AWS_ACCESS_KEY_ID") or aws_access_key_id

View file

@ -108,6 +108,7 @@ async def lollms_model_complete(
history_messages=[], history_messages=[],
enable_cot: bool = False, enable_cot: bool = False,
keyword_extraction=False, keyword_extraction=False,
token_tracker=None,
**kwargs, **kwargs,
) -> Union[str, AsyncIterator[str]]: ) -> Union[str, AsyncIterator[str]]:
"""Complete function for lollms model generation.""" """Complete function for lollms model generation."""
@ -135,7 +136,11 @@ async def lollms_model_complete(
async def lollms_embed( async def lollms_embed(
texts: List[str], embed_model=None, base_url="http://localhost:9600", **kwargs texts: List[str],
embed_model=None,
base_url="http://localhost:9600",
token_tracker=None,
**kwargs,
) -> np.ndarray: ) -> np.ndarray:
""" """
Generate embeddings for a list of texts using lollms server. Generate embeddings for a list of texts using lollms server.
@ -144,6 +149,7 @@ async def lollms_embed(
texts: List of strings to embed texts: List of strings to embed
embed_model: Model name (not used directly as lollms uses configured vectorizer) embed_model: Model name (not used directly as lollms uses configured vectorizer)
base_url: URL of the lollms server base_url: URL of the lollms server
token_tracker: Optional token usage tracker for monitoring API usage
**kwargs: Additional arguments passed to the request **kwargs: Additional arguments passed to the request
Returns: Returns:

View file

@ -39,6 +39,7 @@ async def _ollama_model_if_cache(
system_prompt=None, system_prompt=None,
history_messages=[], history_messages=[],
enable_cot: bool = False, enable_cot: bool = False,
token_tracker=None,
**kwargs, **kwargs,
) -> Union[str, AsyncIterator[str]]: ) -> Union[str, AsyncIterator[str]]:
if enable_cot: if enable_cot:
@ -74,13 +75,47 @@ async def _ollama_model_if_cache(
"""cannot cache stream response and process reasoning""" """cannot cache stream response and process reasoning"""
async def inner(): async def inner():
accumulated_response = ""
try: try:
async for chunk in response: async for chunk in response:
yield chunk["message"]["content"] chunk_content = chunk["message"]["content"]
accumulated_response += chunk_content
yield chunk_content
except Exception as e: except Exception as e:
logger.error(f"Error in stream response: {str(e)}") logger.error(f"Error in stream response: {str(e)}")
raise raise
finally: finally:
# Track token usage for streaming if token tracker is provided
if token_tracker:
# Estimate prompt tokens: roughly 4 characters per token for English text
prompt_text = ""
if system_prompt:
prompt_text += system_prompt + " "
prompt_text += (
" ".join(
[msg.get("content", "") for msg in history_messages]
)
+ " "
)
prompt_text += prompt
prompt_tokens = len(prompt_text) // 4 + (
1 if len(prompt_text) % 4 else 0
)
# Estimate completion tokens from accumulated response
completion_tokens = len(accumulated_response) // 4 + (
1 if len(accumulated_response) % 4 else 0
)
total_tokens = prompt_tokens + completion_tokens
token_tracker.add_usage(
{
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": total_tokens,
}
)
try: try:
await ollama_client._client.aclose() await ollama_client._client.aclose()
logger.debug("Successfully closed Ollama client for streaming") logger.debug("Successfully closed Ollama client for streaming")
@ -91,6 +126,35 @@ async def _ollama_model_if_cache(
else: else:
model_response = response["message"]["content"] model_response = response["message"]["content"]
# Track token usage if token tracker is provided
# Note: Ollama doesn't provide token usage in chat responses, so we estimate
if token_tracker:
# Estimate prompt tokens: roughly 4 characters per token for English text
prompt_text = ""
if system_prompt:
prompt_text += system_prompt + " "
prompt_text += (
" ".join([msg.get("content", "") for msg in history_messages]) + " "
)
prompt_text += prompt
prompt_tokens = len(prompt_text) // 4 + (
1 if len(prompt_text) % 4 else 0
)
# Estimate completion tokens from response
completion_tokens = len(model_response) // 4 + (
1 if len(model_response) % 4 else 0
)
total_tokens = prompt_tokens + completion_tokens
token_tracker.add_usage(
{
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": total_tokens,
}
)
""" """
If the model also wraps its thoughts in a specific tag, If the model also wraps its thoughts in a specific tag,
this information is not needed for the final this information is not needed for the final
@ -126,6 +190,7 @@ async def ollama_model_complete(
history_messages=[], history_messages=[],
enable_cot: bool = False, enable_cot: bool = False,
keyword_extraction=False, keyword_extraction=False,
token_tracker=None,
**kwargs, **kwargs,
) -> Union[str, AsyncIterator[str]]: ) -> Union[str, AsyncIterator[str]]:
keyword_extraction = kwargs.pop("keyword_extraction", None) keyword_extraction = kwargs.pop("keyword_extraction", None)
@ -138,11 +203,14 @@ async def ollama_model_complete(
system_prompt=system_prompt, system_prompt=system_prompt,
history_messages=history_messages, history_messages=history_messages,
enable_cot=enable_cot, enable_cot=enable_cot,
token_tracker=token_tracker,
**kwargs, **kwargs,
) )
async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray: async def ollama_embed(
texts: list[str], embed_model, token_tracker=None, **kwargs
) -> np.ndarray:
api_key = kwargs.pop("api_key", None) api_key = kwargs.pop("api_key", None)
headers = { headers = {
"Content-Type": "application/json", "Content-Type": "application/json",
@ -160,6 +228,21 @@ async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
data = await ollama_client.embed( data = await ollama_client.embed(
model=embed_model, input=texts, options=options model=embed_model, input=texts, options=options
) )
# Track token usage if token tracker is provided
# Note: Ollama doesn't provide token usage in embedding responses, so we estimate
if token_tracker:
# Estimate tokens: roughly 4 characters per token for English text
total_chars = sum(len(text) for text in texts)
estimated_tokens = total_chars // 4 + (1 if total_chars % 4 else 0)
token_tracker.add_usage(
{
"prompt_tokens": estimated_tokens,
"completion_tokens": 0,
"total_tokens": estimated_tokens,
}
)
return np.array(data["embeddings"]) return np.array(data["embeddings"])
except Exception as e: except Exception as e:
logger.error(f"Error in ollama_embed: {str(e)}") logger.error(f"Error in ollama_embed: {str(e)}")

View file

@ -12,6 +12,7 @@ from .utils import (
logger, logger,
compute_mdhash_id, compute_mdhash_id,
Tokenizer, Tokenizer,
TokenTracker,
is_float_regex, is_float_regex,
sanitize_and_normalize_extracted_text, sanitize_and_normalize_extracted_text,
pack_user_ass_to_openai_messages, pack_user_ass_to_openai_messages,
@ -126,6 +127,7 @@ async def _handle_entity_relation_summary(
seperator: str, seperator: str,
global_config: dict, global_config: dict,
llm_response_cache: BaseKVStorage | None = None, llm_response_cache: BaseKVStorage | None = None,
token_tracker: TokenTracker | None = None,
) -> tuple[str, bool]: ) -> tuple[str, bool]:
"""Handle entity relation description summary using map-reduce approach. """Handle entity relation description summary using map-reduce approach.
@ -188,6 +190,7 @@ async def _handle_entity_relation_summary(
current_list, current_list,
global_config, global_config,
llm_response_cache, llm_response_cache,
token_tracker,
) )
return final_summary, True # LLM was used for final summarization return final_summary, True # LLM was used for final summarization
@ -243,6 +246,7 @@ async def _handle_entity_relation_summary(
chunk, chunk,
global_config, global_config,
llm_response_cache, llm_response_cache,
token_tracker,
) )
new_summaries.append(summary) new_summaries.append(summary)
llm_was_used = True # Mark that LLM was used in reduce phase llm_was_used = True # Mark that LLM was used in reduce phase
@ -257,6 +261,7 @@ async def _summarize_descriptions(
description_list: list[str], description_list: list[str],
global_config: dict, global_config: dict,
llm_response_cache: BaseKVStorage | None = None, llm_response_cache: BaseKVStorage | None = None,
token_tracker: TokenTracker | None = None,
) -> str: ) -> str:
"""Helper function to summarize a list of descriptions using LLM. """Helper function to summarize a list of descriptions using LLM.
@ -312,9 +317,10 @@ async def _summarize_descriptions(
# Use LLM function with cache (higher priority for summary generation) # Use LLM function with cache (higher priority for summary generation)
summary, _ = await use_llm_func_with_cache( summary, _ = await use_llm_func_with_cache(
use_prompt, use_prompt,
use_llm_func, use_llm_func=use_llm_func,
llm_response_cache=llm_response_cache, hashing_kv=llm_response_cache,
cache_type="summary", cache_type="summary",
token_tracker=token_tracker,
) )
return summary return summary
@ -405,7 +411,7 @@ async def _handle_single_relationship_extraction(
): # treat "relationship" and "relation" interchangeable ): # treat "relationship" and "relation" interchangeable
if len(record_attributes) > 1 and "relation" in record_attributes[0]: if len(record_attributes) > 1 and "relation" in record_attributes[0]:
logger.warning( logger.warning(
f"{chunk_key}: LLM output format error; found {len(record_attributes)}/5 fields on REALTION `{record_attributes[1]}`~`{record_attributes[2] if len(record_attributes) >2 else 'N/A'}`" f"{chunk_key}: LLM output format error; found {len(record_attributes)}/5 fields on REALTION `{record_attributes[1]}`~`{record_attributes[2] if len(record_attributes) > 2 else 'N/A'}`"
) )
logger.debug(record_attributes) logger.debug(record_attributes)
return None return None
@ -463,7 +469,6 @@ async def _handle_single_relationship_extraction(
file_path=file_path, file_path=file_path,
timestamp=timestamp, timestamp=timestamp,
metadata=metadata, metadata=metadata,
) )
except ValueError as e: except ValueError as e:
@ -2037,6 +2042,7 @@ async def extract_entities(
pipeline_status_lock=None, pipeline_status_lock=None,
llm_response_cache: BaseKVStorage | None = None, llm_response_cache: BaseKVStorage | None = None,
text_chunks_storage: BaseKVStorage | None = None, text_chunks_storage: BaseKVStorage | None = None,
token_tracker: TokenTracker | None = None,
) -> list: ) -> list:
use_llm_func: callable = global_config["llm_model_func"] use_llm_func: callable = global_config["llm_model_func"]
entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"] entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
@ -2150,12 +2156,13 @@ async def extract_entities(
final_result, timestamp = await use_llm_func_with_cache( final_result, timestamp = await use_llm_func_with_cache(
entity_extraction_user_prompt, entity_extraction_user_prompt,
use_llm_func, use_llm_func=use_llm_func,
system_prompt=entity_extraction_system_prompt, system_prompt=entity_extraction_system_prompt,
llm_response_cache=llm_response_cache, hashing_kv=llm_response_cache,
cache_type="extract", cache_type="extract",
chunk_id=chunk_key, chunk_id=chunk_key,
cache_keys_collector=cache_keys_collector, cache_keys_collector=cache_keys_collector,
token_tracker=token_tracker,
) )
history = pack_user_ass_to_openai_messages( history = pack_user_ass_to_openai_messages(
@ -2177,16 +2184,16 @@ async def extract_entities(
if entity_extract_max_gleaning > 0: if entity_extract_max_gleaning > 0:
glean_result, timestamp = await use_llm_func_with_cache( glean_result, timestamp = await use_llm_func_with_cache(
entity_continue_extraction_user_prompt, entity_continue_extraction_user_prompt,
use_llm_func, use_llm_func=use_llm_func,
system_prompt=entity_extraction_system_prompt, system_prompt=entity_extraction_system_prompt,
llm_response_cache=llm_response_cache, llm_response_cache=llm_response_cache,
history_messages=history, history_messages=history,
cache_type="extract", cache_type="extract",
chunk_id=chunk_key, chunk_id=chunk_key,
cache_keys_collector=cache_keys_collector, cache_keys_collector=cache_keys_collector,
token_tracker=token_tracker,
) )
# Process gleaning result separately with file path and metadata # Process gleaning result separately with file path and metadata
glean_nodes, glean_edges = await _process_extraction_result( glean_nodes, glean_edges = await _process_extraction_result(
glean_result, glean_result,
@ -2300,7 +2307,7 @@ async def extract_entities(
await asyncio.wait(pending) await asyncio.wait(pending)
# Add progress prefix to the exception message # Add progress prefix to the exception message
progress_prefix = f"C[{processed_chunks+1}/{total_chunks}]" progress_prefix = f"C[{processed_chunks + 1}/{total_chunks}]"
# Re-raise the original exception with a prefix # Re-raise the original exception with a prefix
prefixed_exception = create_prefixed_exception(first_exception, progress_prefix) prefixed_exception = create_prefixed_exception(first_exception, progress_prefix)
@ -2324,6 +2331,7 @@ async def kg_query(
system_prompt: str | None = None, system_prompt: str | None = None,
chunks_vdb: BaseVectorStorage = None, chunks_vdb: BaseVectorStorage = None,
return_raw_data: Literal[True] = False, return_raw_data: Literal[True] = False,
token_tracker: TokenTracker | None = None,
) -> dict[str, Any]: ... ) -> dict[str, Any]: ...
@ -2341,6 +2349,7 @@ async def kg_query(
chunks_vdb: BaseVectorStorage = None, chunks_vdb: BaseVectorStorage = None,
metadata_filters: list | None = None, metadata_filters: list | None = None,
return_raw_data: Literal[False] = False, return_raw_data: Literal[False] = False,
token_tracker: TokenTracker | None = None,
) -> str | AsyncIterator[str]: ... ) -> str | AsyncIterator[str]: ...
@ -2355,6 +2364,7 @@ async def kg_query(
hashing_kv: BaseKVStorage | None = None, hashing_kv: BaseKVStorage | None = None,
system_prompt: str | None = None, system_prompt: str | None = None,
chunks_vdb: BaseVectorStorage = None, chunks_vdb: BaseVectorStorage = None,
token_tracker: TokenTracker | None = None,
) -> QueryResult: ) -> QueryResult:
""" """
Execute knowledge graph query and return unified QueryResult object. Execute knowledge graph query and return unified QueryResult object.
@ -2422,7 +2432,7 @@ async def kg_query(
return QueryResult(content=cached_response) return QueryResult(content=cached_response)
hl_keywords, ll_keywords = await get_keywords_from_query( hl_keywords, ll_keywords = await get_keywords_from_query(
query, query_param, global_config, hashing_kv query, query_param, global_config, hashing_kv, token_tracker
) )
logger.debug(f"High-level keywords: {hl_keywords}") logger.debug(f"High-level keywords: {hl_keywords}")
@ -2526,6 +2536,7 @@ async def kg_query(
history_messages=query_param.conversation_history, history_messages=query_param.conversation_history,
enable_cot=True, enable_cot=True,
stream=query_param.stream, stream=query_param.stream,
token_tracker=token_tracker,
) )
if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"): if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"):
@ -2583,6 +2594,7 @@ async def get_keywords_from_query(
query_param: QueryParam, query_param: QueryParam,
global_config: dict[str, str], global_config: dict[str, str],
hashing_kv: BaseKVStorage | None = None, hashing_kv: BaseKVStorage | None = None,
token_tracker: TokenTracker | None = None,
) -> tuple[list[str], list[str]]: ) -> tuple[list[str], list[str]]:
""" """
Retrieves high-level and low-level keywords for RAG operations. Retrieves high-level and low-level keywords for RAG operations.
@ -2605,7 +2617,7 @@ async def get_keywords_from_query(
# Extract keywords using extract_keywords_only function which already supports conversation history # Extract keywords using extract_keywords_only function which already supports conversation history
hl_keywords, ll_keywords = await extract_keywords_only( hl_keywords, ll_keywords = await extract_keywords_only(
query, query_param, global_config, hashing_kv query, query_param, global_config, hashing_kv, token_tracker
) )
return hl_keywords, ll_keywords return hl_keywords, ll_keywords
@ -2615,6 +2627,7 @@ async def extract_keywords_only(
param: QueryParam, param: QueryParam,
global_config: dict[str, str], global_config: dict[str, str],
hashing_kv: BaseKVStorage | None = None, hashing_kv: BaseKVStorage | None = None,
token_tracker: TokenTracker | None = None,
) -> tuple[list[str], list[str]]: ) -> tuple[list[str], list[str]]:
""" """
Extract high-level and low-level keywords from the given 'text' using the LLM. Extract high-level and low-level keywords from the given 'text' using the LLM.
@ -2668,7 +2681,9 @@ async def extract_keywords_only(
# Apply higher priority (5) to query relation LLM function # Apply higher priority (5) to query relation LLM function
use_model_func = partial(use_model_func, _priority=5) use_model_func = partial(use_model_func, _priority=5)
result = await use_model_func(kw_prompt, keyword_extraction=True) result = await use_model_func(
kw_prompt, keyword_extraction=True, token_tracker=token_tracker
)
# 5. Parse out JSON from the LLM response # 5. Parse out JSON from the LLM response
result = remove_think_tags(result) result = remove_think_tags(result)
@ -2746,7 +2761,10 @@ async def _get_vector_context(
cosine_threshold = chunks_vdb.cosine_better_than_threshold cosine_threshold = chunks_vdb.cosine_better_than_threshold
results = await chunks_vdb.query( results = await chunks_vdb.query(
query, top_k=search_top_k, query_embedding=query_embedding, metadata_filter=query_param.metadata_filter query,
top_k=search_top_k,
query_embedding=query_embedding,
metadata_filter=query_param.metadata_filter,
) )
if not results: if not results:
logger.info( logger.info(
@ -2763,7 +2781,7 @@ async def _get_vector_context(
"file_path": result.get("file_path", "unknown_source"), "file_path": result.get("file_path", "unknown_source"),
"source_type": "vector", # Mark the source type "source_type": "vector", # Mark the source type
"chunk_id": result.get("id"), # Add chunk_id for deduplication "chunk_id": result.get("id"), # Add chunk_id for deduplication
"metadata": result.get("metadata") "metadata": result.get("metadata"),
} }
valid_chunks.append(chunk_with_metadata) valid_chunks.append(chunk_with_metadata)
@ -3529,8 +3547,9 @@ async def _get_node_data(
f"Query nodes: {query} (top_k:{query_param.top_k}, cosine:{entities_vdb.cosine_better_than_threshold})" f"Query nodes: {query} (top_k:{query_param.top_k}, cosine:{entities_vdb.cosine_better_than_threshold})"
) )
results = await entities_vdb.query(
results = await entities_vdb.query(query, top_k=query_param.top_k, metadata_filter=query_param.metadata_filter) query, top_k=query_param.top_k, metadata_filter=query_param.metadata_filter
)
if not len(results): if not len(results):
return [], [] return [], []
@ -3538,7 +3557,6 @@ async def _get_node_data(
# Extract all entity IDs from your results list # Extract all entity IDs from your results list
node_ids = [r["entity_name"] for r in results] node_ids = [r["entity_name"] for r in results]
# Extract all entity IDs from your results list # Extract all entity IDs from your results list
node_ids = [r["entity_name"] for r in results] node_ids = [r["entity_name"] for r in results]
@ -3810,7 +3828,9 @@ async def _get_edge_data(
f"Query edges: {keywords} (top_k:{query_param.top_k}, cosine:{relationships_vdb.cosine_better_than_threshold})" f"Query edges: {keywords} (top_k:{query_param.top_k}, cosine:{relationships_vdb.cosine_better_than_threshold})"
) )
results = await relationships_vdb.query(keywords, top_k=query_param.top_k, metadata_filter=query_param.metadata_filter) results = await relationships_vdb.query(
keywords, top_k=query_param.top_k, metadata_filter=query_param.metadata_filter
)
if not len(results): if not len(results):
return [], [] return [], []
@ -4104,6 +4124,7 @@ async def naive_query(
hashing_kv: BaseKVStorage | None = None, hashing_kv: BaseKVStorage | None = None,
system_prompt: str | None = None, system_prompt: str | None = None,
return_raw_data: Literal[True] = True, return_raw_data: Literal[True] = True,
token_tracker: TokenTracker | None = None,
) -> dict[str, Any]: ... ) -> dict[str, Any]: ...
@ -4116,6 +4137,7 @@ async def naive_query(
hashing_kv: BaseKVStorage | None = None, hashing_kv: BaseKVStorage | None = None,
system_prompt: str | None = None, system_prompt: str | None = None,
return_raw_data: Literal[False] = False, return_raw_data: Literal[False] = False,
token_tracker: TokenTracker | None = None,
) -> str | AsyncIterator[str]: ... ) -> str | AsyncIterator[str]: ...
@ -4126,6 +4148,7 @@ async def naive_query(
global_config: dict[str, str], global_config: dict[str, str],
hashing_kv: BaseKVStorage | None = None, hashing_kv: BaseKVStorage | None = None,
system_prompt: str | None = None, system_prompt: str | None = None,
token_tracker: TokenTracker | None = None,
) -> QueryResult: ) -> QueryResult:
""" """
Execute naive query and return unified QueryResult object. Execute naive query and return unified QueryResult object.
@ -4321,6 +4344,7 @@ async def naive_query(
history_messages=query_param.conversation_history, history_messages=query_param.conversation_history,
enable_cot=True, enable_cot=True,
stream=query_param.stream, stream=query_param.stream,
token_tracker=token_tracker,
) )
if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"): if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"):

View file

@ -4,6 +4,7 @@ import weakref
import asyncio import asyncio
import html import html
import csv import csv
import contextvars
import json import json
import logging import logging
import logging.handlers import logging.handlers
@ -507,6 +508,7 @@ def priority_limit_async_func_call(
task_id, task_id,
args, args,
kwargs, kwargs,
ctx,
) = await asyncio.wait_for(queue.get(), timeout=1.0) ) = await asyncio.wait_for(queue.get(), timeout=1.0)
except asyncio.TimeoutError: except asyncio.TimeoutError:
continue continue
@ -536,11 +538,15 @@ def priority_limit_async_func_call(
try: try:
# Execute function with timeout protection # Execute function with timeout protection
if max_execution_timeout is not None: if max_execution_timeout is not None:
# Run the function in the captured context
task = ctx.run(lambda: asyncio.create_task(func(*args, **kwargs)))
result = await asyncio.wait_for( result = await asyncio.wait_for(
func(*args, **kwargs), timeout=max_execution_timeout task, timeout=max_execution_timeout
) )
else: else:
result = await func(*args, **kwargs) # Run the function in the captured context
task = ctx.run(lambda: asyncio.create_task(func(*args, **kwargs)))
result = await task
# Set result if future is still valid # Set result if future is still valid
if not task_state.future.done(): if not task_state.future.done():
@ -791,6 +797,9 @@ def priority_limit_async_func_call(
future=future, start_time=asyncio.get_event_loop().time() future=future, start_time=asyncio.get_event_loop().time()
) )
# Capture current context
ctx = contextvars.copy_context()
try: try:
# Register task state # Register task state
async with task_states_lock: async with task_states_lock:
@ -809,13 +818,13 @@ def priority_limit_async_func_call(
if _queue_timeout is not None: if _queue_timeout is not None:
await asyncio.wait_for( await asyncio.wait_for(
queue.put( queue.put(
(_priority, current_count, task_id, args, kwargs) (_priority, current_count, task_id, args, kwargs, ctx)
), ),
timeout=_queue_timeout, timeout=_queue_timeout,
) )
else: else:
await queue.put( await queue.put(
(_priority, current_count, task_id, args, kwargs) (_priority, current_count, task_id, args, kwargs, ctx)
) )
except asyncio.TimeoutError: except asyncio.TimeoutError:
raise QueueFullError( raise QueueFullError(
@ -1472,8 +1481,7 @@ async def aexport_data(
else: else:
raise ValueError( raise ValueError(
f"Unsupported file format: {file_format}. " f"Unsupported file format: {file_format}. Choose from: csv, excel, md, txt"
f"Choose from: csv, excel, md, txt"
) )
if file_format is not None: if file_format is not None:
print(f"Data exported to: {output_path} with format: {file_format}") print(f"Data exported to: {output_path} with format: {file_format}")
@ -1601,6 +1609,8 @@ async def use_llm_func_with_cache(
cache_type: str = "extract", cache_type: str = "extract",
chunk_id: str | None = None, chunk_id: str | None = None,
cache_keys_collector: list = None, cache_keys_collector: list = None,
hashing_kv: "BaseKVStorage | None" = None,
token_tracker=None,
) -> tuple[str, int]: ) -> tuple[str, int]:
"""Call LLM function with cache support and text sanitization """Call LLM function with cache support and text sanitization
@ -1685,7 +1695,10 @@ async def use_llm_func_with_cache(
kwargs["max_tokens"] = max_tokens kwargs["max_tokens"] = max_tokens
res: str = await use_llm_func( res: str = await use_llm_func(
safe_user_prompt, system_prompt=safe_system_prompt, **kwargs safe_user_prompt,
system_prompt=safe_system_prompt,
token_tracker=token_tracker,
**kwargs,
) )
res = remove_think_tags(res) res = remove_think_tags(res)
@ -1720,7 +1733,10 @@ async def use_llm_func_with_cache(
try: try:
res = await use_llm_func( res = await use_llm_func(
safe_user_prompt, system_prompt=safe_system_prompt, **kwargs safe_user_prompt,
system_prompt=safe_system_prompt,
token_tracker=token_tracker,
**kwargs,
) )
except Exception as e: except Exception as e:
# Add [LLM func] prefix to error message # Add [LLM func] prefix to error message
@ -2216,52 +2232,74 @@ async def pick_by_vector_similarity(
return all_chunk_ids[:num_of_chunks] return all_chunk_ids[:num_of_chunks]
from contextvars import ContextVar
class TokenTracker: class TokenTracker:
"""Track token usage for LLM calls.""" """Track token usage for LLM calls using ContextVars for concurrency support."""
_usage_var: ContextVar[dict] = ContextVar("token_usage", default=None)
def __init__(self): def __init__(self):
self.reset() # No instance state needed as we use ContextVar
pass
def __enter__(self): def __enter__(self):
self.reset() self.reset()
return self return self
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
print(self) # Optional: Log usage on exit if needed
pass
def reset(self):
    """Initialize/Reset token usage for the current context.

    Binds a brand-new zeroed dict to ``_usage_var``; a dict obtained before
    the reset is left untouched.
    """
    counter_names = (
        "prompt_tokens",
        "completion_tokens",
        "total_tokens",
        "call_count",
    )
    self._usage_var.set(dict.fromkeys(counter_names, 0))
def add_usage(self, token_counts): def _get_current_usage(self) -> dict:
"""Get the usage dict for the current context, initializing if necessary."""
usage = self._usage_var.get()
if usage is None:
usage = {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0,
"call_count": 0,
}
self._usage_var.set(usage)
return usage
def add_usage(self, token_counts: dict):
    """Add token usage from one LLM call.

    Args:
        token_counts: A dictionary containing prompt_tokens, completion_tokens,
            total_tokens. A missing key or a ``None`` value counts as 0; when
            total_tokens is missing or ``None`` it is derived as
            prompt_tokens + completion_tokens.
    """
    usage = self._get_current_usage()

    # Treat an explicit None the same as a missing key so the additions below
    # cannot raise TypeError.
    prompt = token_counts.get("prompt_tokens") or 0
    completion = token_counts.get("completion_tokens") or 0
    reported_total = token_counts.get("total_tokens")

    usage["prompt_tokens"] += prompt
    usage["completion_tokens"] += completion

    # If total_tokens is provided, use it directly; otherwise calculate the sum
    if reported_total is not None:
        usage["total_tokens"] += reported_total
    else:
        usage["total_tokens"] += prompt + completion

    usage["call_count"] += 1
def get_usage(self):
    """Get current usage statistics.

    Returns:
        A shallow copy of the current context's usage dict, so changing the
        returned dict does not alter the tracked counters.
    """
    live_counters = self._get_current_usage()
    return dict(live_counters)
def __str__(self): def __str__(self):
usage = self.get_usage() usage = self.get_usage()
@ -2273,6 +2311,26 @@ class TokenTracker:
) )
def estimate_embedding_tokens(texts: list[str], tokenizer: Tokenizer) -> int:
    """Estimate the token count of texts that are about to be embedded.

    The estimate is the summed length of ``tokenizer.encode(text)`` over every
    non-empty text; it is an approximation for usage tracking, not a count
    reported by an embedding API.

    Args:
        texts: List of text strings to be embedded
        tokenizer: Tokenizer instance for encoding

    Returns:
        Estimated total token count for all texts (0 for an empty list)
    """
    # Falsy entries (empty strings) are skipped and never reach the tokenizer.
    return sum(len(tokenizer.encode(chunk)) for chunk in texts if chunk)
async def apply_rerank_if_enabled( async def apply_rerank_if_enabled(
query: str, query: str,
retrieved_docs: list[dict], retrieved_docs: list[dict],