diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index ff9ce8b0..3b82d718 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -91,6 +91,7 @@ from lightrag.operate import ( from lightrag.constants import GRAPH_FIELD_SEP from lightrag.utils import ( Tokenizer, + TokenTracker, TiktokenTokenizer, EmbeddingFunc, always_get_an_event_loop, @@ -106,6 +107,7 @@ from lightrag.utils import ( subtract_source_ids, make_relation_chunk_key, normalize_source_ids_limit_method, + call_llm_with_tracker, ) from lightrag.types import KnowledgeGraph from dotenv import load_dotenv @@ -2430,6 +2432,7 @@ class LightRAG: fields at the top level. """ global_config = asdict(self) + token_tracker = TokenTracker() # Create a copy of param to avoid modifying the original data_param = QueryParam( @@ -2467,6 +2470,7 @@ class LightRAG: hashing_kv=self.llm_response_cache, system_prompt=None, chunks_vdb=self.chunks_vdb, + token_tracker=token_tracker, ) elif data_param.mode == "naive": logger.debug(f"[aquery_data] Using naive_query for mode: {data_param.mode}") @@ -2477,6 +2481,7 @@ class LightRAG: global_config, hashing_kv=self.llm_response_cache, system_prompt=None, + token_tracker=token_tracker, ) elif data_param.mode == "bypass": logger.debug("[aquery_data] Using bypass mode") @@ -2522,6 +2527,14 @@ class LightRAG: else: logger.warning("[aquery_data] No data section found in query result") + usage = token_tracker.get_usage() + if isinstance(final_data, dict): + final_data.setdefault("metadata", {}) + final_data["metadata"]["token_usage"] = usage + logger.info( + "[aquery_data] Token usage (mode=%s): %s", data_param.mode, usage + ) + await self._query_done() return final_data @@ -2548,6 +2561,7 @@ class LightRAG: logger.debug(f"[aquery_llm] Query param: {param}") global_config = asdict(self) + token_tracker = TokenTracker() try: query_result = None @@ -2564,6 +2578,7 @@ class LightRAG: hashing_kv=self.llm_response_cache, system_prompt=system_prompt, chunks_vdb=self.chunks_vdb, + token_tracker=token_tracker, ) elif param.mode == "naive": query_result = await naive_query( @@ -2573,6 +2588,7 @@ class LightRAG: global_config, hashing_kv=self.llm_response_cache, system_prompt=system_prompt, + token_tracker=token_tracker, ) elif param.mode == "bypass": # Bypass mode: directly use LLM without knowledge retrieval @@ -2581,19 +2597,28 @@ class LightRAG: use_llm_func = partial(use_llm_func, _priority=8) param.stream = True if param.stream is None else param.stream - response = await use_llm_func( + response = await call_llm_with_tracker( + use_llm_func, query.strip(), system_prompt=system_prompt, history_messages=param.conversation_history, enable_cot=True, stream=param.stream, + token_tracker=token_tracker, ) + usage = token_tracker.get_usage() + usage_metadata = {"token_usage": usage} if type(response) is str: + logger.info( + "[aquery_llm] Token usage (mode=%s, bypass=True, stream=False): %s", + param.mode, + usage, + ) return { "status": "success", "message": "Bypass mode LLM non streaming response", "data": {}, - "metadata": {}, + "metadata": usage_metadata, "llm_response": { "content": response, "response_iterator": None, @@ -2601,11 +2626,16 @@ class LightRAG: }, } else: + logger.info( + "[aquery_llm] Token usage (mode=%s, bypass=True, stream=True): %s", + param.mode, + usage, + ) return { "status": "success", "message": "Bypass mode LLM streaming response", "data": {}, - "metadata": {}, + "metadata": usage_metadata, "llm_response": { "content": None, "response_iterator": response, @@ -2616,17 +2646,26 @@ class LightRAG: raise ValueError(f"Unknown mode {param.mode}") await self._query_done() + usage = token_tracker.get_usage() # Check if query_result is None if query_result is None: + failure_metadata = { + "failure_reason": "no_results", + "mode": param.mode, + "token_usage": usage, + } + logger.info( + "[aquery_llm] Token usage (mode=%s, status=failure, stream=%s): %s", + param.mode, + bool(param.stream), + usage, + ) return { "status": "failure", "message": "Query returned no results", "data": {}, - "metadata": { - "failure_reason": "no_results", - "mode": param.mode, - }, + "metadata": failure_metadata, "llm_response": { "content": PROMPTS["fail_response"], "response_iterator": None, @@ -2636,6 +2675,8 @@ class LightRAG: # Extract structured data from query result raw_data = query_result.raw_data or {} + metadata = raw_data.setdefault("metadata", {}) + metadata["token_usage"] = usage raw_data["llm_response"] = { "content": query_result.content if not query_result.is_streaming @@ -2645,17 +2686,30 @@ class LightRAG: else None, "is_streaming": query_result.is_streaming, } + logger.info( + "[aquery_llm] Token usage (mode=%s, status=success, stream=%s): %s", + param.mode, + query_result.is_streaming, + usage, + ) return raw_data except Exception as e: logger.error(f"Query failed: {e}") # Return error response + usage = token_tracker.get_usage() + logger.info( + "[aquery_llm] Token usage (mode=%s, status=error, stream=%s): %s", + param.mode, + bool(param.stream), + usage, + ) return { "status": "failure", "message": f"Query failed: {str(e)}", "data": {}, - "metadata": {}, + "metadata": {"token_usage": usage}, "llm_response": { "content": None, "response_iterator": None, diff --git a/lightrag/operate.py b/lightrag/operate.py index 8ecec587..7be89c90 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -11,6 +11,7 @@ from lightrag.utils import ( logger, compute_mdhash_id, Tokenizer, + TokenTracker, is_float_regex, sanitize_and_normalize_extracted_text, pack_user_ass_to_openai_messages, @@ -34,6 +35,7 @@ from lightrag.utils import ( apply_source_ids_limit, merge_source_ids, make_relation_chunk_key, + call_llm_with_tracker, ) from lightrag.base import ( BaseGraphStorage, @@ -2767,6 +2769,7 @@ async def kg_query( hashing_kv: BaseKVStorage | None = None, system_prompt: str | None = None, chunks_vdb: BaseVectorStorage = None, + token_tracker: TokenTracker | None = None, ) -> QueryResult | None: """ Execute knowledge graph query and return unified QueryResult object. @@ -2781,6 +2784,7 @@ async def kg_query( global_config: Global configuration hashing_kv: Cache storage system_prompt: System prompt + token_tracker: Optional TokenTracker for aggregating token usage chunks_vdb: Document chunks vector database Returns: @@ -2809,9 +2813,20 @@ async def kg_query( use_model_func = partial(use_model_func, _priority=5) hl_keywords, ll_keywords = await get_keywords_from_query( - query, query_param, global_config, hashing_kv + query, + query_param, + global_config, + hashing_kv, + token_tracker=token_tracker, ) + def _attach_token_usage(raw_data: dict[str, Any] | None) -> dict[str, Any] | None: + if raw_data is None or token_tracker is None: + return raw_data + metadata = raw_data.setdefault("metadata", {}) + metadata["token_usage"] = token_tracker.get_usage() + return raw_data + logger.debug(f"High-level keywords: {hl_keywords}") logger.debug(f"Low-level keywords: {ll_keywords}") @@ -2849,9 +2864,8 @@ async def kg_query( # Return different content based on query parameters if query_param.only_need_context and not query_param.only_need_prompt: - return QueryResult( - content=context_result.context, raw_data=context_result.raw_data - ) + raw_data = _attach_token_usage(context_result.raw_data) + return QueryResult(content=context_result.context, raw_data=raw_data) user_prompt = f"\n\n{query_param.user_prompt}" if query_param.user_prompt else "n/a" response_type = ( @@ -2872,7 +2886,8 @@ async def kg_query( if query_param.only_need_prompt: prompt_content = "\n\n".join([sys_prompt, "---User Query---", user_query]) - return QueryResult(content=prompt_content, raw_data=context_result.raw_data) + raw_data = _attach_token_usage(context_result.raw_data) + return QueryResult(content=prompt_content, raw_data=raw_data) # Call LLM tokenizer: Tokenizer = global_config["tokenizer"] @@ -2908,12 +2923,14 @@ async def kg_query( ) response = cached_response else: - response = await use_model_func( + response = await call_llm_with_tracker( + use_model_func, user_query, system_prompt=sys_prompt, history_messages=query_param.conversation_history, enable_cot=True, stream=query_param.stream, + token_tracker=token_tracker, ) if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"): @@ -2956,12 +2973,14 @@ async def kg_query( .strip() ) - return QueryResult(content=response, raw_data=context_result.raw_data) + raw_data = _attach_token_usage(context_result.raw_data) + return QueryResult(content=response, raw_data=raw_data) else: # Streaming response (AsyncIterator) + raw_data = _attach_token_usage(context_result.raw_data) return QueryResult( response_iterator=response, - raw_data=context_result.raw_data, + raw_data=raw_data, is_streaming=True, ) @@ -2971,6 +2990,7 @@ async def get_keywords_from_query( query_param: QueryParam, global_config: dict[str, str], hashing_kv: BaseKVStorage | None = None, + token_tracker: TokenTracker | None = None, ) -> tuple[list[str], list[str]]: """ Retrieves high-level and low-level keywords for RAG operations. @@ -2983,6 +3003,7 @@ async def get_keywords_from_query( query_param: Query parameters that may contain pre-defined keywords global_config: Global configuration dictionary hashing_kv: Optional key-value storage for caching results + token_tracker: Optional TokenTracker to aggregate keyword extraction usage Returns: A tuple containing (high_level_keywords, low_level_keywords) @@ -2993,7 +3014,11 @@ async def get_keywords_from_query( # Extract keywords using extract_keywords_only function which already supports conversation history hl_keywords, ll_keywords = await extract_keywords_only( - query, query_param, global_config, hashing_kv + query, + query_param, + global_config, + hashing_kv, + token_tracker=token_tracker, ) return hl_keywords, ll_keywords @@ -3003,11 +3028,19 @@ async def extract_keywords_only( param: QueryParam, global_config: dict[str, str], hashing_kv: BaseKVStorage | None = None, + token_tracker: TokenTracker | None = None, ) -> tuple[list[str], list[str]]: """ Extract high-level and low-level keywords from the given 'text' using the LLM. This method does NOT build the final RAG context or provide a final answer. It ONLY extracts keywords (hl_keywords, ll_keywords). + + Args: + text: Input text to analyze for keywords. + param: Query parameters controlling keyword extraction. + global_config: Global configuration dictionary. + hashing_kv: Optional cache storage to reuse previous keyword calls. + token_tracker: Optional TokenTracker to aggregate keyword extraction usage. """ # 1. Handle cache if needed - add cache type for keywords @@ -3056,7 +3089,12 @@ async def extract_keywords_only( # Apply higher priority (5) to query relation LLM function use_model_func = partial(use_model_func, _priority=5) - result = await use_model_func(kw_prompt, keyword_extraction=True) + result = await call_llm_with_tracker( + use_model_func, + kw_prompt, + keyword_extraction=True, + token_tracker=token_tracker, + ) # 5. Parse out JSON from the LLM response result = remove_think_tags(result) @@ -4489,6 +4527,7 @@ async def naive_query( global_config: dict[str, str], hashing_kv: BaseKVStorage | None = None, system_prompt: str | None = None, + token_tracker: TokenTracker | None = None, return_raw_data: Literal[True] = True, ) -> dict[str, Any]: ... @@ -4501,6 +4540,7 @@ async def naive_query( global_config: dict[str, str], hashing_kv: BaseKVStorage | None = None, system_prompt: str | None = None, + token_tracker: TokenTracker | None = None, return_raw_data: Literal[False] = False, ) -> str | AsyncIterator[str]: ... @@ -4512,6 +4552,7 @@ async def naive_query( global_config: dict[str, str], hashing_kv: BaseKVStorage | None = None, system_prompt: str | None = None, + token_tracker: TokenTracker | None = None, ) -> QueryResult | None: """ Execute naive query and return unified QueryResult object. @@ -4633,6 +4674,8 @@ async def naive_query( "total_chunks_found": len(chunks), "final_chunks_count": len(processed_chunks_with_ref_ids), } + if token_tracker is not None: + raw_data["metadata"]["token_usage"] = token_tracker.get_usage() # Build text_units_context from processed chunks with reference IDs text_units_context = [] @@ -4697,12 +4740,14 @@ async def naive_query( ) response = cached_response else: - response = await use_model_func( + response = await call_llm_with_tracker( + use_model_func, user_query, system_prompt=sys_prompt, history_messages=query_param.conversation_history, enable_cot=True, stream=query_param.stream, + token_tracker=token_tracker, ) if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"): diff --git a/lightrag/utils.py b/lightrag/utils.py index bfa3cac4..a54d0f58 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -2282,6 +2282,34 @@ class TokenTracker: f"Completion tokens: {usage['completion_tokens']}, " f"Total tokens: {usage['total_tokens']}" ) + +async def call_llm_with_tracker( + func: Callable[..., Any], + *args: Any, + token_tracker: "TokenTracker | None" = None, + **kwargs: Any, + ) -> Any: + """ + Invoke an async LLM callable while optionally passing a TokenTracker. + + If the underlying callable doesn't accept the `token_tracker` kwarg, + the call is retried without it for backward compatibility. + """ + call_kwargs = dict(kwargs) + if token_tracker is not None: + call_kwargs["token_tracker"] = token_tracker + + try: + return await func(*args, **call_kwargs) + except TypeError as err: + if token_tracker is not None and "token_tracker" in str(err): + logger.debug( + "LLM callable %s does not accept token_tracker, retrying without it.", + getattr(func, "__name__", repr(func)), + ) + call_kwargs.pop("token_tracker", None) + return await func(*args, **call_kwargs) + raise async def apply_rerank_if_enabled(