Add TokenTracker and update LLM call functions to track token usage

This commit is contained in:
Mohit Tilwani 2025-10-24 18:26:48 +02:00
parent 7dfc224be9
commit 646ed12c13
3 changed files with 146 additions and 19 deletions

View file

@ -91,6 +91,7 @@ from lightrag.operate import (
from lightrag.constants import GRAPH_FIELD_SEP
from lightrag.utils import (
Tokenizer,
TokenTracker,
TiktokenTokenizer,
EmbeddingFunc,
always_get_an_event_loop,
@ -106,6 +107,7 @@ from lightrag.utils import (
subtract_source_ids,
make_relation_chunk_key,
normalize_source_ids_limit_method,
call_llm_with_tracker,
)
from lightrag.types import KnowledgeGraph
from dotenv import load_dotenv
@ -2430,6 +2432,7 @@ class LightRAG:
fields at the top level.
"""
global_config = asdict(self)
token_tracker = TokenTracker()
# Create a copy of param to avoid modifying the original
data_param = QueryParam(
@ -2467,6 +2470,7 @@ class LightRAG:
hashing_kv=self.llm_response_cache,
system_prompt=None,
chunks_vdb=self.chunks_vdb,
token_tracker=token_tracker,
)
elif data_param.mode == "naive":
logger.debug(f"[aquery_data] Using naive_query for mode: {data_param.mode}")
@ -2477,6 +2481,7 @@ class LightRAG:
global_config,
hashing_kv=self.llm_response_cache,
system_prompt=None,
token_tracker=token_tracker,
)
elif data_param.mode == "bypass":
logger.debug("[aquery_data] Using bypass mode")
@ -2522,6 +2527,14 @@ class LightRAG:
else:
logger.warning("[aquery_data] No data section found in query result")
usage = token_tracker.get_usage()
if isinstance(final_data, dict):
final_data.setdefault("metadata", {})
final_data["metadata"]["token_usage"] = usage
logger.info(
"[aquery_data] Token usage (mode=%s): %s", data_param.mode, usage
)
await self._query_done()
return final_data
@ -2548,6 +2561,7 @@ class LightRAG:
logger.debug(f"[aquery_llm] Query param: {param}")
global_config = asdict(self)
token_tracker = TokenTracker()
try:
query_result = None
@ -2564,6 +2578,7 @@ class LightRAG:
hashing_kv=self.llm_response_cache,
system_prompt=system_prompt,
chunks_vdb=self.chunks_vdb,
token_tracker=token_tracker,
)
elif param.mode == "naive":
query_result = await naive_query(
@ -2573,6 +2588,7 @@ class LightRAG:
global_config,
hashing_kv=self.llm_response_cache,
system_prompt=system_prompt,
token_tracker=token_tracker,
)
elif param.mode == "bypass":
# Bypass mode: directly use LLM without knowledge retrieval
@ -2581,19 +2597,28 @@ class LightRAG:
use_llm_func = partial(use_llm_func, _priority=8)
param.stream = True if param.stream is None else param.stream
response = await use_llm_func(
response = await call_llm_with_tracker(
use_llm_func,
query.strip(),
system_prompt=system_prompt,
history_messages=param.conversation_history,
enable_cot=True,
stream=param.stream,
token_tracker=token_tracker,
)
usage = token_tracker.get_usage()
usage_metadata = {"token_usage": usage}
if type(response) is str:
logger.info(
"[aquery_llm] Token usage (mode=%s, bypass=True, stream=False): %s",
param.mode,
usage,
)
return {
"status": "success",
"message": "Bypass mode LLM non streaming response",
"data": {},
"metadata": {},
"metadata": usage_metadata,
"llm_response": {
"content": response,
"response_iterator": None,
@ -2601,11 +2626,16 @@ class LightRAG:
},
}
else:
logger.info(
"[aquery_llm] Token usage (mode=%s, bypass=True, stream=True): %s",
param.mode,
usage,
)
return {
"status": "success",
"message": "Bypass mode LLM streaming response",
"data": {},
"metadata": {},
"metadata": usage_metadata,
"llm_response": {
"content": None,
"response_iterator": response,
@ -2616,17 +2646,26 @@ class LightRAG:
raise ValueError(f"Unknown mode {param.mode}")
await self._query_done()
usage = token_tracker.get_usage()
# Check if query_result is None
if query_result is None:
failure_metadata = {
"failure_reason": "no_results",
"mode": param.mode,
"token_usage": usage,
}
logger.info(
"[aquery_llm] Token usage (mode=%s, status=failure, stream=%s): %s",
param.mode,
bool(param.stream),
usage,
)
return {
"status": "failure",
"message": "Query returned no results",
"data": {},
"metadata": {
"failure_reason": "no_results",
"mode": param.mode,
},
"metadata": failure_metadata,
"llm_response": {
"content": PROMPTS["fail_response"],
"response_iterator": None,
@ -2636,6 +2675,8 @@ class LightRAG:
# Extract structured data from query result
raw_data = query_result.raw_data or {}
metadata = raw_data.setdefault("metadata", {})
metadata["token_usage"] = usage
raw_data["llm_response"] = {
"content": query_result.content
if not query_result.is_streaming
@ -2645,17 +2686,30 @@ class LightRAG:
else None,
"is_streaming": query_result.is_streaming,
}
logger.info(
"[aquery_llm] Token usage (mode=%s, status=success, stream=%s): %s",
param.mode,
query_result.is_streaming,
usage,
)
return raw_data
except Exception as e:
logger.error(f"Query failed: {e}")
# Return error response
usage = token_tracker.get_usage()
logger.info(
"[aquery_llm] Token usage (mode=%s, status=error, stream=%s): %s",
param.mode,
bool(param.stream),
usage,
)
return {
"status": "failure",
"message": f"Query failed: {str(e)}",
"data": {},
"metadata": {},
"metadata": {"token_usage": usage},
"llm_response": {
"content": None,
"response_iterator": None,

View file

@ -11,6 +11,7 @@ from lightrag.utils import (
logger,
compute_mdhash_id,
Tokenizer,
TokenTracker,
is_float_regex,
sanitize_and_normalize_extracted_text,
pack_user_ass_to_openai_messages,
@ -34,6 +35,7 @@ from lightrag.utils import (
apply_source_ids_limit,
merge_source_ids,
make_relation_chunk_key,
call_llm_with_tracker,
)
from lightrag.base import (
BaseGraphStorage,
@ -2767,6 +2769,7 @@ async def kg_query(
hashing_kv: BaseKVStorage | None = None,
system_prompt: str | None = None,
chunks_vdb: BaseVectorStorage = None,
token_tracker: TokenTracker | None = None,
) -> QueryResult | None:
"""
Execute knowledge graph query and return unified QueryResult object.
@ -2781,6 +2784,7 @@ async def kg_query(
global_config: Global configuration
hashing_kv: Cache storage
system_prompt: System prompt
token_tracker: Optional TokenTracker for aggregating token usage
chunks_vdb: Document chunks vector database
Returns:
@ -2809,9 +2813,20 @@ async def kg_query(
use_model_func = partial(use_model_func, _priority=5)
hl_keywords, ll_keywords = await get_keywords_from_query(
query, query_param, global_config, hashing_kv
query,
query_param,
global_config,
hashing_kv,
token_tracker=token_tracker,
)
def _attach_token_usage(raw_data: dict[str, Any] | None) -> dict[str, Any] | None:
if raw_data is None or token_tracker is None:
return raw_data
metadata = raw_data.setdefault("metadata", {})
metadata["token_usage"] = token_tracker.get_usage()
return raw_data
logger.debug(f"High-level keywords: {hl_keywords}")
logger.debug(f"Low-level keywords: {ll_keywords}")
@ -2849,9 +2864,8 @@ async def kg_query(
# Return different content based on query parameters
if query_param.only_need_context and not query_param.only_need_prompt:
return QueryResult(
content=context_result.context, raw_data=context_result.raw_data
)
raw_data = _attach_token_usage(context_result.raw_data)
return QueryResult(content=context_result.context, raw_data=raw_data)
user_prompt = f"\n\n{query_param.user_prompt}" if query_param.user_prompt else "n/a"
response_type = (
@ -2872,7 +2886,8 @@ async def kg_query(
if query_param.only_need_prompt:
prompt_content = "\n\n".join([sys_prompt, "---User Query---", user_query])
return QueryResult(content=prompt_content, raw_data=context_result.raw_data)
raw_data = _attach_token_usage(context_result.raw_data)
return QueryResult(content=prompt_content, raw_data=raw_data)
# Call LLM
tokenizer: Tokenizer = global_config["tokenizer"]
@ -2908,12 +2923,14 @@ async def kg_query(
)
response = cached_response
else:
response = await use_model_func(
response = await call_llm_with_tracker(
use_model_func,
user_query,
system_prompt=sys_prompt,
history_messages=query_param.conversation_history,
enable_cot=True,
stream=query_param.stream,
token_tracker=token_tracker,
)
if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"):
@ -2956,12 +2973,14 @@ async def kg_query(
.strip()
)
return QueryResult(content=response, raw_data=context_result.raw_data)
raw_data = _attach_token_usage(context_result.raw_data)
return QueryResult(content=response, raw_data=raw_data)
else:
# Streaming response (AsyncIterator)
raw_data = _attach_token_usage(context_result.raw_data)
return QueryResult(
response_iterator=response,
raw_data=context_result.raw_data,
raw_data=raw_data,
is_streaming=True,
)
@ -2971,6 +2990,7 @@ async def get_keywords_from_query(
query_param: QueryParam,
global_config: dict[str, str],
hashing_kv: BaseKVStorage | None = None,
token_tracker: TokenTracker | None = None,
) -> tuple[list[str], list[str]]:
"""
Retrieves high-level and low-level keywords for RAG operations.
@ -2983,6 +3003,7 @@ async def get_keywords_from_query(
query_param: Query parameters that may contain pre-defined keywords
global_config: Global configuration dictionary
hashing_kv: Optional key-value storage for caching results
token_tracker: Optional TokenTracker to aggregate keyword extraction usage
Returns:
A tuple containing (high_level_keywords, low_level_keywords)
@ -2993,7 +3014,11 @@ async def get_keywords_from_query(
# Extract keywords using extract_keywords_only function which already supports conversation history
hl_keywords, ll_keywords = await extract_keywords_only(
query, query_param, global_config, hashing_kv
query,
query_param,
global_config,
hashing_kv,
token_tracker=token_tracker,
)
return hl_keywords, ll_keywords
@ -3003,11 +3028,19 @@ async def extract_keywords_only(
param: QueryParam,
global_config: dict[str, str],
hashing_kv: BaseKVStorage | None = None,
token_tracker: TokenTracker | None = None,
) -> tuple[list[str], list[str]]:
"""
Extract high-level and low-level keywords from the given 'text' using the LLM.
This method does NOT build the final RAG context or provide a final answer.
It ONLY extracts keywords (hl_keywords, ll_keywords).
Args:
text: Input text to analyze for keywords.
param: Query parameters controlling keyword extraction.
global_config: Global configuration dictionary.
hashing_kv: Optional cache storage to reuse previous keyword calls.
token_tracker: Optional TokenTracker to aggregate keyword extraction usage.
"""
# 1. Handle cache if needed - add cache type for keywords
@ -3056,7 +3089,12 @@ async def extract_keywords_only(
# Apply higher priority (5) to query relation LLM function
use_model_func = partial(use_model_func, _priority=5)
result = await use_model_func(kw_prompt, keyword_extraction=True)
result = await call_llm_with_tracker(
use_model_func,
kw_prompt,
keyword_extraction=True,
token_tracker=token_tracker,
)
# 5. Parse out JSON from the LLM response
result = remove_think_tags(result)
@ -4489,6 +4527,7 @@ async def naive_query(
global_config: dict[str, str],
hashing_kv: BaseKVStorage | None = None,
system_prompt: str | None = None,
token_tracker: TokenTracker | None = None,
return_raw_data: Literal[True] = True,
) -> dict[str, Any]: ...
@ -4501,6 +4540,7 @@ async def naive_query(
global_config: dict[str, str],
hashing_kv: BaseKVStorage | None = None,
system_prompt: str | None = None,
token_tracker: TokenTracker | None = None,
return_raw_data: Literal[False] = False,
) -> str | AsyncIterator[str]: ...
@ -4512,6 +4552,7 @@ async def naive_query(
global_config: dict[str, str],
hashing_kv: BaseKVStorage | None = None,
system_prompt: str | None = None,
token_tracker: TokenTracker | None = None,
) -> QueryResult | None:
"""
Execute naive query and return unified QueryResult object.
@ -4633,6 +4674,8 @@ async def naive_query(
"total_chunks_found": len(chunks),
"final_chunks_count": len(processed_chunks_with_ref_ids),
}
if token_tracker is not None:
raw_data["metadata"]["token_usage"] = token_tracker.get_usage()
# Build text_units_context from processed chunks with reference IDs
text_units_context = []
@ -4697,12 +4740,14 @@ async def naive_query(
)
response = cached_response
else:
response = await use_model_func(
response = await call_llm_with_tracker(
use_model_func,
user_query,
system_prompt=sys_prompt,
history_messages=query_param.conversation_history,
enable_cot=True,
stream=query_param.stream,
token_tracker=token_tracker,
)
if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"):

View file

@ -2282,6 +2282,34 @@ class TokenTracker:
f"Completion tokens: {usage['completion_tokens']}, "
f"Total tokens: {usage['total_tokens']}"
)
async def call_llm_with_tracker(
    func: Callable[..., Any],
    *args: Any,
    token_tracker: "TokenTracker | None" = None,
    **kwargs: Any,
) -> Any:
    """
    Invoke an async LLM callable, optionally forwarding a TokenTracker.

    Older model functions may not accept the ``token_tracker`` keyword.
    For backward compatibility the callable's signature is inspected
    first and the kwarg is only forwarded when it is accepted (either as
    a named parameter or via ``**kwargs``).  Signature inspection is
    preferred over blind call-and-retry because retrying after a
    ``TypeError`` raised from *inside* ``func`` could invoke the
    expensive, side-effecting LLM request twice.

    Args:
        func: Async callable that performs the LLM request.
        *args: Positional arguments passed through to ``func``.
        token_tracker: Optional TokenTracker aggregating token usage.
        **kwargs: Additional keyword arguments passed through to ``func``.

    Returns:
        Whatever ``func`` returns (e.g. a string or an async iterator).
    """
    import inspect

    call_kwargs = dict(kwargs)
    if token_tracker is not None:
        try:
            parameters = inspect.signature(func).parameters.values()
            accepts_tracker = any(
                p.name == "token_tracker"
                or p.kind is inspect.Parameter.VAR_KEYWORD
                for p in parameters
            )
        except (TypeError, ValueError):
            # Some callables (builtins, certain partial objects) are not
            # introspectable; optimistically forward the kwarg and rely
            # on the retry fallback below.
            accepts_tracker = True
        if accepts_tracker:
            call_kwargs["token_tracker"] = token_tracker
    try:
        return await func(*args, **call_kwargs)
    except TypeError as err:
        # Last-resort fallback for uninspectable callables that still
        # reject the kwarg: retry once without it.
        # NOTE(review): if the TypeError originated inside ``func`` after
        # it started executing, this retry re-invokes the LLM — kept only
        # for callables whose signature could not be inspected above.
        if "token_tracker" in call_kwargs and "token_tracker" in str(err):
            logger.debug(
                "LLM callable %s does not accept token_tracker, retrying without it.",
                getattr(func, "__name__", repr(func)),
            )
            call_kwargs.pop("token_tracker", None)
            return await func(*args, **call_kwargs)
        raise
async def apply_rerank_if_enabled(