Add TokenTracker and update LLM call functions to track token usage
This commit is contained in:
parent
7dfc224be9
commit
646ed12c13
3 changed files with 146 additions and 19 deletions
|
|
@ -91,6 +91,7 @@ from lightrag.operate import (
|
|||
from lightrag.constants import GRAPH_FIELD_SEP
|
||||
from lightrag.utils import (
|
||||
Tokenizer,
|
||||
TokenTracker,
|
||||
TiktokenTokenizer,
|
||||
EmbeddingFunc,
|
||||
always_get_an_event_loop,
|
||||
|
|
@ -106,6 +107,7 @@ from lightrag.utils import (
|
|||
subtract_source_ids,
|
||||
make_relation_chunk_key,
|
||||
normalize_source_ids_limit_method,
|
||||
call_llm_with_tracker,
|
||||
)
|
||||
from lightrag.types import KnowledgeGraph
|
||||
from dotenv import load_dotenv
|
||||
|
|
@ -2430,6 +2432,7 @@ class LightRAG:
|
|||
fields at the top level.
|
||||
"""
|
||||
global_config = asdict(self)
|
||||
token_tracker = TokenTracker()
|
||||
|
||||
# Create a copy of param to avoid modifying the original
|
||||
data_param = QueryParam(
|
||||
|
|
@ -2467,6 +2470,7 @@ class LightRAG:
|
|||
hashing_kv=self.llm_response_cache,
|
||||
system_prompt=None,
|
||||
chunks_vdb=self.chunks_vdb,
|
||||
token_tracker=token_tracker,
|
||||
)
|
||||
elif data_param.mode == "naive":
|
||||
logger.debug(f"[aquery_data] Using naive_query for mode: {data_param.mode}")
|
||||
|
|
@ -2477,6 +2481,7 @@ class LightRAG:
|
|||
global_config,
|
||||
hashing_kv=self.llm_response_cache,
|
||||
system_prompt=None,
|
||||
token_tracker=token_tracker,
|
||||
)
|
||||
elif data_param.mode == "bypass":
|
||||
logger.debug("[aquery_data] Using bypass mode")
|
||||
|
|
@ -2522,6 +2527,14 @@ class LightRAG:
|
|||
else:
|
||||
logger.warning("[aquery_data] No data section found in query result")
|
||||
|
||||
usage = token_tracker.get_usage()
|
||||
if isinstance(final_data, dict):
|
||||
final_data.setdefault("metadata", {})
|
||||
final_data["metadata"]["token_usage"] = usage
|
||||
logger.info(
|
||||
"[aquery_data] Token usage (mode=%s): %s", data_param.mode, usage
|
||||
)
|
||||
|
||||
await self._query_done()
|
||||
return final_data
|
||||
|
||||
|
|
@ -2548,6 +2561,7 @@ class LightRAG:
|
|||
logger.debug(f"[aquery_llm] Query param: {param}")
|
||||
|
||||
global_config = asdict(self)
|
||||
token_tracker = TokenTracker()
|
||||
|
||||
try:
|
||||
query_result = None
|
||||
|
|
@ -2564,6 +2578,7 @@ class LightRAG:
|
|||
hashing_kv=self.llm_response_cache,
|
||||
system_prompt=system_prompt,
|
||||
chunks_vdb=self.chunks_vdb,
|
||||
token_tracker=token_tracker,
|
||||
)
|
||||
elif param.mode == "naive":
|
||||
query_result = await naive_query(
|
||||
|
|
@ -2573,6 +2588,7 @@ class LightRAG:
|
|||
global_config,
|
||||
hashing_kv=self.llm_response_cache,
|
||||
system_prompt=system_prompt,
|
||||
token_tracker=token_tracker,
|
||||
)
|
||||
elif param.mode == "bypass":
|
||||
# Bypass mode: directly use LLM without knowledge retrieval
|
||||
|
|
@ -2581,19 +2597,28 @@ class LightRAG:
|
|||
use_llm_func = partial(use_llm_func, _priority=8)
|
||||
|
||||
param.stream = True if param.stream is None else param.stream
|
||||
response = await use_llm_func(
|
||||
response = await call_llm_with_tracker(
|
||||
use_llm_func,
|
||||
query.strip(),
|
||||
system_prompt=system_prompt,
|
||||
history_messages=param.conversation_history,
|
||||
enable_cot=True,
|
||||
stream=param.stream,
|
||||
token_tracker=token_tracker,
|
||||
)
|
||||
usage = token_tracker.get_usage()
|
||||
usage_metadata = {"token_usage": usage}
|
||||
if type(response) is str:
|
||||
logger.info(
|
||||
"[aquery_llm] Token usage (mode=%s, bypass=True, stream=False): %s",
|
||||
param.mode,
|
||||
usage,
|
||||
)
|
||||
return {
|
||||
"status": "success",
|
||||
"message": "Bypass mode LLM non streaming response",
|
||||
"data": {},
|
||||
"metadata": {},
|
||||
"metadata": usage_metadata,
|
||||
"llm_response": {
|
||||
"content": response,
|
||||
"response_iterator": None,
|
||||
|
|
@ -2601,11 +2626,16 @@ class LightRAG:
|
|||
},
|
||||
}
|
||||
else:
|
||||
logger.info(
|
||||
"[aquery_llm] Token usage (mode=%s, bypass=True, stream=True): %s",
|
||||
param.mode,
|
||||
usage,
|
||||
)
|
||||
return {
|
||||
"status": "success",
|
||||
"message": "Bypass mode LLM streaming response",
|
||||
"data": {},
|
||||
"metadata": {},
|
||||
"metadata": usage_metadata,
|
||||
"llm_response": {
|
||||
"content": None,
|
||||
"response_iterator": response,
|
||||
|
|
@ -2616,17 +2646,26 @@ class LightRAG:
|
|||
raise ValueError(f"Unknown mode {param.mode}")
|
||||
|
||||
await self._query_done()
|
||||
usage = token_tracker.get_usage()
|
||||
|
||||
# Check if query_result is None
|
||||
if query_result is None:
|
||||
failure_metadata = {
|
||||
"failure_reason": "no_results",
|
||||
"mode": param.mode,
|
||||
"token_usage": usage,
|
||||
}
|
||||
logger.info(
|
||||
"[aquery_llm] Token usage (mode=%s, status=failure, stream=%s): %s",
|
||||
param.mode,
|
||||
bool(param.stream),
|
||||
usage,
|
||||
)
|
||||
return {
|
||||
"status": "failure",
|
||||
"message": "Query returned no results",
|
||||
"data": {},
|
||||
"metadata": {
|
||||
"failure_reason": "no_results",
|
||||
"mode": param.mode,
|
||||
},
|
||||
"metadata": failure_metadata,
|
||||
"llm_response": {
|
||||
"content": PROMPTS["fail_response"],
|
||||
"response_iterator": None,
|
||||
|
|
@ -2636,6 +2675,8 @@ class LightRAG:
|
|||
|
||||
# Extract structured data from query result
|
||||
raw_data = query_result.raw_data or {}
|
||||
metadata = raw_data.setdefault("metadata", {})
|
||||
metadata["token_usage"] = usage
|
||||
raw_data["llm_response"] = {
|
||||
"content": query_result.content
|
||||
if not query_result.is_streaming
|
||||
|
|
@ -2645,17 +2686,30 @@ class LightRAG:
|
|||
else None,
|
||||
"is_streaming": query_result.is_streaming,
|
||||
}
|
||||
logger.info(
|
||||
"[aquery_llm] Token usage (mode=%s, status=success, stream=%s): %s",
|
||||
param.mode,
|
||||
query_result.is_streaming,
|
||||
usage,
|
||||
)
|
||||
|
||||
return raw_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Query failed: {e}")
|
||||
# Return error response
|
||||
usage = token_tracker.get_usage()
|
||||
logger.info(
|
||||
"[aquery_llm] Token usage (mode=%s, status=error, stream=%s): %s",
|
||||
param.mode,
|
||||
bool(param.stream),
|
||||
usage,
|
||||
)
|
||||
return {
|
||||
"status": "failure",
|
||||
"message": f"Query failed: {str(e)}",
|
||||
"data": {},
|
||||
"metadata": {},
|
||||
"metadata": {"token_usage": usage},
|
||||
"llm_response": {
|
||||
"content": None,
|
||||
"response_iterator": None,
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from lightrag.utils import (
|
|||
logger,
|
||||
compute_mdhash_id,
|
||||
Tokenizer,
|
||||
TokenTracker,
|
||||
is_float_regex,
|
||||
sanitize_and_normalize_extracted_text,
|
||||
pack_user_ass_to_openai_messages,
|
||||
|
|
@ -34,6 +35,7 @@ from lightrag.utils import (
|
|||
apply_source_ids_limit,
|
||||
merge_source_ids,
|
||||
make_relation_chunk_key,
|
||||
call_llm_with_tracker,
|
||||
)
|
||||
from lightrag.base import (
|
||||
BaseGraphStorage,
|
||||
|
|
@ -2767,6 +2769,7 @@ async def kg_query(
|
|||
hashing_kv: BaseKVStorage | None = None,
|
||||
system_prompt: str | None = None,
|
||||
chunks_vdb: BaseVectorStorage = None,
|
||||
token_tracker: TokenTracker | None = None,
|
||||
) -> QueryResult | None:
|
||||
"""
|
||||
Execute knowledge graph query and return unified QueryResult object.
|
||||
|
|
@ -2781,6 +2784,7 @@ async def kg_query(
|
|||
global_config: Global configuration
|
||||
hashing_kv: Cache storage
|
||||
system_prompt: System prompt
|
||||
token_tracker: Optional TokenTracker for aggregating token usage
|
||||
chunks_vdb: Document chunks vector database
|
||||
|
||||
Returns:
|
||||
|
|
@ -2809,9 +2813,20 @@ async def kg_query(
|
|||
use_model_func = partial(use_model_func, _priority=5)
|
||||
|
||||
hl_keywords, ll_keywords = await get_keywords_from_query(
|
||||
query, query_param, global_config, hashing_kv
|
||||
query,
|
||||
query_param,
|
||||
global_config,
|
||||
hashing_kv,
|
||||
token_tracker=token_tracker,
|
||||
)
|
||||
|
||||
def _attach_token_usage(raw_data: dict[str, Any] | None) -> dict[str, Any] | None:
|
||||
if raw_data is None or token_tracker is None:
|
||||
return raw_data
|
||||
metadata = raw_data.setdefault("metadata", {})
|
||||
metadata["token_usage"] = token_tracker.get_usage()
|
||||
return raw_data
|
||||
|
||||
logger.debug(f"High-level keywords: {hl_keywords}")
|
||||
logger.debug(f"Low-level keywords: {ll_keywords}")
|
||||
|
||||
|
|
@ -2849,9 +2864,8 @@ async def kg_query(
|
|||
|
||||
# Return different content based on query parameters
|
||||
if query_param.only_need_context and not query_param.only_need_prompt:
|
||||
return QueryResult(
|
||||
content=context_result.context, raw_data=context_result.raw_data
|
||||
)
|
||||
raw_data = _attach_token_usage(context_result.raw_data)
|
||||
return QueryResult(content=context_result.context, raw_data=raw_data)
|
||||
|
||||
user_prompt = f"\n\n{query_param.user_prompt}" if query_param.user_prompt else "n/a"
|
||||
response_type = (
|
||||
|
|
@ -2872,7 +2886,8 @@ async def kg_query(
|
|||
|
||||
if query_param.only_need_prompt:
|
||||
prompt_content = "\n\n".join([sys_prompt, "---User Query---", user_query])
|
||||
return QueryResult(content=prompt_content, raw_data=context_result.raw_data)
|
||||
raw_data = _attach_token_usage(context_result.raw_data)
|
||||
return QueryResult(content=prompt_content, raw_data=raw_data)
|
||||
|
||||
# Call LLM
|
||||
tokenizer: Tokenizer = global_config["tokenizer"]
|
||||
|
|
@ -2908,12 +2923,14 @@ async def kg_query(
|
|||
)
|
||||
response = cached_response
|
||||
else:
|
||||
response = await use_model_func(
|
||||
response = await call_llm_with_tracker(
|
||||
use_model_func,
|
||||
user_query,
|
||||
system_prompt=sys_prompt,
|
||||
history_messages=query_param.conversation_history,
|
||||
enable_cot=True,
|
||||
stream=query_param.stream,
|
||||
token_tracker=token_tracker,
|
||||
)
|
||||
|
||||
if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"):
|
||||
|
|
@ -2956,12 +2973,14 @@ async def kg_query(
|
|||
.strip()
|
||||
)
|
||||
|
||||
return QueryResult(content=response, raw_data=context_result.raw_data)
|
||||
raw_data = _attach_token_usage(context_result.raw_data)
|
||||
return QueryResult(content=response, raw_data=raw_data)
|
||||
else:
|
||||
# Streaming response (AsyncIterator)
|
||||
raw_data = _attach_token_usage(context_result.raw_data)
|
||||
return QueryResult(
|
||||
response_iterator=response,
|
||||
raw_data=context_result.raw_data,
|
||||
raw_data=raw_data,
|
||||
is_streaming=True,
|
||||
)
|
||||
|
||||
|
|
@ -2971,6 +2990,7 @@ async def get_keywords_from_query(
|
|||
query_param: QueryParam,
|
||||
global_config: dict[str, str],
|
||||
hashing_kv: BaseKVStorage | None = None,
|
||||
token_tracker: TokenTracker | None = None,
|
||||
) -> tuple[list[str], list[str]]:
|
||||
"""
|
||||
Retrieves high-level and low-level keywords for RAG operations.
|
||||
|
|
@ -2983,6 +3003,7 @@ async def get_keywords_from_query(
|
|||
query_param: Query parameters that may contain pre-defined keywords
|
||||
global_config: Global configuration dictionary
|
||||
hashing_kv: Optional key-value storage for caching results
|
||||
token_tracker: Optional TokenTracker to aggregate keyword extraction usage
|
||||
|
||||
Returns:
|
||||
A tuple containing (high_level_keywords, low_level_keywords)
|
||||
|
|
@ -2993,7 +3014,11 @@ async def get_keywords_from_query(
|
|||
|
||||
# Extract keywords using extract_keywords_only function which already supports conversation history
|
||||
hl_keywords, ll_keywords = await extract_keywords_only(
|
||||
query, query_param, global_config, hashing_kv
|
||||
query,
|
||||
query_param,
|
||||
global_config,
|
||||
hashing_kv,
|
||||
token_tracker=token_tracker,
|
||||
)
|
||||
return hl_keywords, ll_keywords
|
||||
|
||||
|
|
@ -3003,11 +3028,19 @@ async def extract_keywords_only(
|
|||
param: QueryParam,
|
||||
global_config: dict[str, str],
|
||||
hashing_kv: BaseKVStorage | None = None,
|
||||
token_tracker: TokenTracker | None = None,
|
||||
) -> tuple[list[str], list[str]]:
|
||||
"""
|
||||
Extract high-level and low-level keywords from the given 'text' using the LLM.
|
||||
This method does NOT build the final RAG context or provide a final answer.
|
||||
It ONLY extracts keywords (hl_keywords, ll_keywords).
|
||||
|
||||
Args:
|
||||
text: Input text to analyze for keywords.
|
||||
param: Query parameters controlling keyword extraction.
|
||||
global_config: Global configuration dictionary.
|
||||
hashing_kv: Optional cache storage to reuse previous keyword calls.
|
||||
token_tracker: Optional TokenTracker to aggregate keyword extraction usage.
|
||||
"""
|
||||
|
||||
# 1. Handle cache if needed - add cache type for keywords
|
||||
|
|
@ -3056,7 +3089,12 @@ async def extract_keywords_only(
|
|||
# Apply higher priority (5) to query relation LLM function
|
||||
use_model_func = partial(use_model_func, _priority=5)
|
||||
|
||||
result = await use_model_func(kw_prompt, keyword_extraction=True)
|
||||
result = await call_llm_with_tracker(
|
||||
use_model_func,
|
||||
kw_prompt,
|
||||
keyword_extraction=True,
|
||||
token_tracker=token_tracker,
|
||||
)
|
||||
|
||||
# 5. Parse out JSON from the LLM response
|
||||
result = remove_think_tags(result)
|
||||
|
|
@ -4489,6 +4527,7 @@ async def naive_query(
|
|||
global_config: dict[str, str],
|
||||
hashing_kv: BaseKVStorage | None = None,
|
||||
system_prompt: str | None = None,
|
||||
token_tracker: TokenTracker | None = None,
|
||||
return_raw_data: Literal[True] = True,
|
||||
) -> dict[str, Any]: ...
|
||||
|
||||
|
|
@ -4501,6 +4540,7 @@ async def naive_query(
|
|||
global_config: dict[str, str],
|
||||
hashing_kv: BaseKVStorage | None = None,
|
||||
system_prompt: str | None = None,
|
||||
token_tracker: TokenTracker | None = None,
|
||||
return_raw_data: Literal[False] = False,
|
||||
) -> str | AsyncIterator[str]: ...
|
||||
|
||||
|
|
@ -4512,6 +4552,7 @@ async def naive_query(
|
|||
global_config: dict[str, str],
|
||||
hashing_kv: BaseKVStorage | None = None,
|
||||
system_prompt: str | None = None,
|
||||
token_tracker: TokenTracker | None = None,
|
||||
) -> QueryResult | None:
|
||||
"""
|
||||
Execute naive query and return unified QueryResult object.
|
||||
|
|
@ -4633,6 +4674,8 @@ async def naive_query(
|
|||
"total_chunks_found": len(chunks),
|
||||
"final_chunks_count": len(processed_chunks_with_ref_ids),
|
||||
}
|
||||
if token_tracker is not None:
|
||||
raw_data["metadata"]["token_usage"] = token_tracker.get_usage()
|
||||
|
||||
# Build text_units_context from processed chunks with reference IDs
|
||||
text_units_context = []
|
||||
|
|
@ -4697,12 +4740,14 @@ async def naive_query(
|
|||
)
|
||||
response = cached_response
|
||||
else:
|
||||
response = await use_model_func(
|
||||
response = await call_llm_with_tracker(
|
||||
use_model_func,
|
||||
user_query,
|
||||
system_prompt=sys_prompt,
|
||||
history_messages=query_param.conversation_history,
|
||||
enable_cot=True,
|
||||
stream=query_param.stream,
|
||||
token_tracker=token_tracker,
|
||||
)
|
||||
|
||||
if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"):
|
||||
|
|
|
|||
|
|
@ -2282,6 +2282,34 @@ class TokenTracker:
|
|||
f"Completion tokens: {usage['completion_tokens']}, "
|
||||
f"Total tokens: {usage['total_tokens']}"
|
||||
)
|
||||
|
||||
async def call_llm_with_tracker(
|
||||
func: Callable[..., Any],
|
||||
*args: Any,
|
||||
token_tracker: "TokenTracker | None" = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""
|
||||
Invoke an async LLM callable while optionally passing a TokenTracker.
|
||||
|
||||
If the underlying callable doesn't accept the `token_tracker` kwarg,
|
||||
the call is retried without it for backward compatibility.
|
||||
"""
|
||||
call_kwargs = dict(kwargs)
|
||||
if token_tracker is not None:
|
||||
call_kwargs["token_tracker"] = token_tracker
|
||||
|
||||
try:
|
||||
return await func(*args, **call_kwargs)
|
||||
except TypeError as err:
|
||||
if token_tracker is not None and "token_tracker" in str(err):
|
||||
logger.debug(
|
||||
"LLM callable %s does not accept token_tracker, retrying without it.",
|
||||
getattr(func, "__name__", repr(func)),
|
||||
)
|
||||
call_kwargs.pop("token_tracker", None)
|
||||
return await func(*args, **call_kwargs)
|
||||
raise
|
||||
|
||||
|
||||
async def apply_rerank_if_enabled(
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue