feat(token_tracking): add token tracking to both the query and insert endpoints, and consequently to the pipeline
This commit is contained in:
parent
cd664de057
commit
3a2d3ddb9f
11 changed files with 588 additions and 104 deletions
|
|
@ -397,6 +397,15 @@ def parse_args() -> argparse.Namespace:
|
||||||
"EMBEDDING_BATCH_NUM", DEFAULT_EMBEDDING_BATCH_NUM, int
|
"EMBEDDING_BATCH_NUM", DEFAULT_EMBEDDING_BATCH_NUM, int
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Token tracking configuration
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable-token-tracking",
|
||||||
|
action="store_true",
|
||||||
|
default=get_env_value("ENABLE_TOKEN_TRACKING", False, bool),
|
||||||
|
help="Enable token usage tracking for LLM calls (default: from env or False)",
|
||||||
|
)
|
||||||
|
args.enable_token_tracking = get_env_value("ENABLE_TOKEN_TRACKING", False, bool)
|
||||||
|
|
||||||
ollama_server_infos.LIGHTRAG_NAME = args.simulated_model_name
|
ollama_server_infos.LIGHTRAG_NAME = args.simulated_model_name
|
||||||
ollama_server_infos.LIGHTRAG_TAG = args.simulated_model_tag
|
ollama_server_infos.LIGHTRAG_TAG = args.simulated_model_tag
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -301,7 +301,10 @@ def create_app(args):
|
||||||
Path(args.working_dir).mkdir(parents=True, exist_ok=True)
|
Path(args.working_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
def create_optimized_openai_llm_func(
|
def create_optimized_openai_llm_func(
|
||||||
config_cache: LLMConfigCache, args, llm_timeout: int
|
config_cache: LLMConfigCache,
|
||||||
|
args,
|
||||||
|
llm_timeout: int,
|
||||||
|
enable_token_tracking=False,
|
||||||
):
|
):
|
||||||
"""Create optimized OpenAI LLM function with pre-processed configuration"""
|
"""Create optimized OpenAI LLM function with pre-processed configuration"""
|
||||||
|
|
||||||
|
|
@ -332,13 +335,19 @@ def create_app(args):
|
||||||
history_messages=history_messages,
|
history_messages=history_messages,
|
||||||
base_url=args.llm_binding_host,
|
base_url=args.llm_binding_host,
|
||||||
api_key=args.llm_binding_api_key,
|
api_key=args.llm_binding_api_key,
|
||||||
|
token_tracker=getattr(app.state, "token_tracker", None)
|
||||||
|
if enable_token_tracking and hasattr(app.state, "token_tracker")
|
||||||
|
else None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
return optimized_openai_alike_model_complete
|
return optimized_openai_alike_model_complete
|
||||||
|
|
||||||
def create_optimized_azure_openai_llm_func(
|
def create_optimized_azure_openai_llm_func(
|
||||||
config_cache: LLMConfigCache, args, llm_timeout: int
|
config_cache: LLMConfigCache,
|
||||||
|
args,
|
||||||
|
llm_timeout: int,
|
||||||
|
enable_token_tracking=False,
|
||||||
):
|
):
|
||||||
"""Create optimized Azure OpenAI LLM function with pre-processed configuration"""
|
"""Create optimized Azure OpenAI LLM function with pre-processed configuration"""
|
||||||
|
|
||||||
|
|
@ -359,8 +368,8 @@ def create_app(args):
|
||||||
|
|
||||||
# Use pre-processed configuration to avoid repeated parsing
|
# Use pre-processed configuration to avoid repeated parsing
|
||||||
kwargs["timeout"] = llm_timeout
|
kwargs["timeout"] = llm_timeout
|
||||||
if config_cache.openai_llm_options:
|
if config_cache.azure_openai_llm_options:
|
||||||
kwargs.update(config_cache.openai_llm_options)
|
kwargs.update(config_cache.azure_openai_llm_options)
|
||||||
|
|
||||||
return await azure_openai_complete_if_cache(
|
return await azure_openai_complete_if_cache(
|
||||||
args.llm_model,
|
args.llm_model,
|
||||||
|
|
@ -370,12 +379,15 @@ def create_app(args):
|
||||||
base_url=args.llm_binding_host,
|
base_url=args.llm_binding_host,
|
||||||
api_key=os.getenv("AZURE_OPENAI_API_KEY", args.llm_binding_api_key),
|
api_key=os.getenv("AZURE_OPENAI_API_KEY", args.llm_binding_api_key),
|
||||||
api_version=os.getenv("AZURE_OPENAI_API_VERSION", "2024-08-01-preview"),
|
api_version=os.getenv("AZURE_OPENAI_API_VERSION", "2024-08-01-preview"),
|
||||||
|
token_tracker=getattr(app.state, "token_tracker", None)
|
||||||
|
if enable_token_tracking and hasattr(app.state, "token_tracker")
|
||||||
|
else None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
return optimized_azure_openai_model_complete
|
return optimized_azure_openai_model_complete
|
||||||
|
|
||||||
def create_llm_model_func(binding: str):
|
def create_llm_model_func(binding: str, enable_token_tracking=False):
|
||||||
"""
|
"""
|
||||||
Create LLM model function based on binding type.
|
Create LLM model function based on binding type.
|
||||||
Uses optimized functions for OpenAI bindings and lazy import for others.
|
Uses optimized functions for OpenAI bindings and lazy import for others.
|
||||||
|
|
@ -384,21 +396,42 @@ def create_app(args):
|
||||||
if binding == "lollms":
|
if binding == "lollms":
|
||||||
from lightrag.llm.lollms import lollms_model_complete
|
from lightrag.llm.lollms import lollms_model_complete
|
||||||
|
|
||||||
return lollms_model_complete
|
async def lollms_model_complete_with_tracker(*args, **kwargs):
|
||||||
|
# Add token tracker if enabled
|
||||||
|
if enable_token_tracking and hasattr(app.state, "token_tracker"):
|
||||||
|
kwargs["token_tracker"] = app.state.token_tracker
|
||||||
|
return await lollms_model_complete(*args, **kwargs)
|
||||||
|
|
||||||
|
return lollms_model_complete_with_tracker
|
||||||
elif binding == "ollama":
|
elif binding == "ollama":
|
||||||
from lightrag.llm.ollama import ollama_model_complete
|
from lightrag.llm.ollama import ollama_model_complete
|
||||||
|
|
||||||
return ollama_model_complete
|
async def ollama_model_complete_with_tracker(*args, **kwargs):
|
||||||
|
# Add token tracker if enabled
|
||||||
|
if enable_token_tracking and hasattr(app.state, "token_tracker"):
|
||||||
|
kwargs["token_tracker"] = app.state.token_tracker
|
||||||
|
return await ollama_model_complete(*args, **kwargs)
|
||||||
|
|
||||||
|
return ollama_model_complete_with_tracker
|
||||||
elif binding == "aws_bedrock":
|
elif binding == "aws_bedrock":
|
||||||
return bedrock_model_complete # Already defined locally
|
|
||||||
|
async def bedrock_model_complete_with_tracker(*args, **kwargs):
|
||||||
|
# Add token tracker if enabled
|
||||||
|
if enable_token_tracking and hasattr(app.state, "token_tracker"):
|
||||||
|
kwargs["token_tracker"] = app.state.token_tracker
|
||||||
|
return await bedrock_model_complete(*args, **kwargs)
|
||||||
|
|
||||||
|
return bedrock_model_complete_with_tracker
|
||||||
elif binding == "azure_openai":
|
elif binding == "azure_openai":
|
||||||
# Use optimized function with pre-processed configuration
|
# Use optimized function with pre-processed configuration
|
||||||
return create_optimized_azure_openai_llm_func(
|
return create_optimized_azure_openai_llm_func(
|
||||||
config_cache, args, llm_timeout
|
config_cache, args, llm_timeout, enable_token_tracking
|
||||||
)
|
)
|
||||||
else: # openai and compatible
|
else: # openai and compatible
|
||||||
# Use optimized function with pre-processed configuration
|
# Use optimized function with pre-processed configuration
|
||||||
return create_optimized_openai_llm_func(config_cache, args, llm_timeout)
|
return create_optimized_openai_llm_func(
|
||||||
|
config_cache, args, llm_timeout, enable_token_tracking
|
||||||
|
)
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
raise Exception(f"Failed to import {binding} LLM binding: {e}")
|
raise Exception(f"Failed to import {binding} LLM binding: {e}")
|
||||||
|
|
||||||
|
|
@ -422,7 +455,14 @@ def create_app(args):
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def create_optimized_embedding_function(
|
def create_optimized_embedding_function(
|
||||||
config_cache: LLMConfigCache, binding, model, host, api_key, dimensions, args
|
config_cache: LLMConfigCache,
|
||||||
|
binding,
|
||||||
|
model,
|
||||||
|
host,
|
||||||
|
api_key,
|
||||||
|
dimensions,
|
||||||
|
args,
|
||||||
|
enable_token_tracking=False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Create optimized embedding function with pre-processed configuration for applicable bindings.
|
Create optimized embedding function with pre-processed configuration for applicable bindings.
|
||||||
|
|
@ -435,7 +475,13 @@ def create_app(args):
|
||||||
from lightrag.llm.lollms import lollms_embed
|
from lightrag.llm.lollms import lollms_embed
|
||||||
|
|
||||||
return await lollms_embed(
|
return await lollms_embed(
|
||||||
texts, embed_model=model, host=host, api_key=api_key
|
texts,
|
||||||
|
embed_model=model,
|
||||||
|
host=host,
|
||||||
|
api_key=api_key,
|
||||||
|
token_tracker=getattr(app.state, "token_tracker", None)
|
||||||
|
if enable_token_tracking and hasattr(app.state, "token_tracker")
|
||||||
|
else None,
|
||||||
)
|
)
|
||||||
elif binding == "ollama":
|
elif binding == "ollama":
|
||||||
from lightrag.llm.ollama import ollama_embed
|
from lightrag.llm.ollama import ollama_embed
|
||||||
|
|
@ -455,26 +501,54 @@ def create_app(args):
|
||||||
host=host,
|
host=host,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
options=ollama_options,
|
options=ollama_options,
|
||||||
|
token_tracker=getattr(app.state, "token_tracker", None)
|
||||||
|
if enable_token_tracking and hasattr(app.state, "token_tracker")
|
||||||
|
else None,
|
||||||
)
|
)
|
||||||
elif binding == "azure_openai":
|
elif binding == "azure_openai":
|
||||||
from lightrag.llm.azure_openai import azure_openai_embed
|
from lightrag.llm.azure_openai import azure_openai_embed
|
||||||
|
|
||||||
return await azure_openai_embed(texts, model=model, api_key=api_key)
|
return await azure_openai_embed(
|
||||||
|
texts,
|
||||||
|
model=model,
|
||||||
|
api_key=api_key,
|
||||||
|
token_tracker=getattr(app.state, "token_tracker", None)
|
||||||
|
if enable_token_tracking and hasattr(app.state, "token_tracker")
|
||||||
|
else None,
|
||||||
|
)
|
||||||
elif binding == "aws_bedrock":
|
elif binding == "aws_bedrock":
|
||||||
from lightrag.llm.bedrock import bedrock_embed
|
from lightrag.llm.bedrock import bedrock_embed
|
||||||
|
|
||||||
return await bedrock_embed(texts, model=model)
|
return await bedrock_embed(
|
||||||
|
texts,
|
||||||
|
model=model,
|
||||||
|
token_tracker=getattr(app.state, "token_tracker", None)
|
||||||
|
if enable_token_tracking and hasattr(app.state, "token_tracker")
|
||||||
|
else None,
|
||||||
|
)
|
||||||
elif binding == "jina":
|
elif binding == "jina":
|
||||||
from lightrag.llm.jina import jina_embed
|
from lightrag.llm.jina import jina_embed
|
||||||
|
|
||||||
return await jina_embed(
|
return await jina_embed(
|
||||||
texts, dimensions=dimensions, base_url=host, api_key=api_key
|
texts,
|
||||||
|
dimensions=dimensions,
|
||||||
|
base_url=host,
|
||||||
|
api_key=api_key,
|
||||||
|
token_tracker=getattr(app.state, "token_tracker", None)
|
||||||
|
if enable_token_tracking and hasattr(app.state, "token_tracker")
|
||||||
|
else None,
|
||||||
)
|
)
|
||||||
else: # openai and compatible
|
else: # openai and compatible
|
||||||
from lightrag.llm.openai import openai_embed
|
from lightrag.llm.openai import openai_embed
|
||||||
|
|
||||||
return await openai_embed(
|
return await openai_embed(
|
||||||
texts, model=model, base_url=host, api_key=api_key
|
texts,
|
||||||
|
model=model,
|
||||||
|
base_url=host,
|
||||||
|
api_key=api_key,
|
||||||
|
token_tracker=getattr(app.state, "token_tracker", None)
|
||||||
|
if enable_token_tracking and hasattr(app.state, "token_tracker")
|
||||||
|
else None,
|
||||||
)
|
)
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
raise Exception(f"Failed to import {binding} embedding: {e}")
|
raise Exception(f"Failed to import {binding} embedding: {e}")
|
||||||
|
|
@ -594,7 +668,9 @@ def create_app(args):
|
||||||
rag = LightRAG(
|
rag = LightRAG(
|
||||||
working_dir=args.working_dir,
|
working_dir=args.working_dir,
|
||||||
workspace=args.workspace,
|
workspace=args.workspace,
|
||||||
llm_model_func=create_llm_model_func(args.llm_binding),
|
llm_model_func=create_llm_model_func(
|
||||||
|
args.llm_binding, args.enable_token_tracking
|
||||||
|
),
|
||||||
llm_model_name=args.llm_model,
|
llm_model_name=args.llm_model,
|
||||||
llm_model_max_async=args.max_async,
|
llm_model_max_async=args.max_async,
|
||||||
summary_max_tokens=args.summary_max_tokens,
|
summary_max_tokens=args.summary_max_tokens,
|
||||||
|
|
@ -604,7 +680,16 @@ def create_app(args):
|
||||||
llm_model_kwargs=create_llm_model_kwargs(
|
llm_model_kwargs=create_llm_model_kwargs(
|
||||||
args.llm_binding, args, llm_timeout
|
args.llm_binding, args, llm_timeout
|
||||||
),
|
),
|
||||||
embedding_func=embedding_func,
|
embedding_func=create_optimized_embedding_function(
|
||||||
|
config_cache,
|
||||||
|
args.embedding_binding,
|
||||||
|
args.embedding_model,
|
||||||
|
args.embedding_binding_host,
|
||||||
|
args.embedding_binding_api_key,
|
||||||
|
args.embedding_dim,
|
||||||
|
args,
|
||||||
|
args.enable_token_tracking,
|
||||||
|
),
|
||||||
default_llm_timeout=llm_timeout,
|
default_llm_timeout=llm_timeout,
|
||||||
default_embedding_timeout=embedding_timeout,
|
default_embedding_timeout=embedding_timeout,
|
||||||
kv_storage=args.kv_storage,
|
kv_storage=args.kv_storage,
|
||||||
|
|
@ -629,6 +714,17 @@ def create_app(args):
|
||||||
logger.error(f"Failed to initialize LightRAG: {e}")
|
logger.error(f"Failed to initialize LightRAG: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
# Initialize token tracking if enabled
|
||||||
|
token_tracker = None
|
||||||
|
if args.enable_token_tracking:
|
||||||
|
from lightrag.utils import TokenTracker
|
||||||
|
|
||||||
|
token_tracker = TokenTracker()
|
||||||
|
logger.info("Token tracking enabled")
|
||||||
|
|
||||||
|
# Add token tracker to the app state for use in endpoints
|
||||||
|
app.state.token_tracker = token_tracker
|
||||||
|
|
||||||
# Add routes
|
# Add routes
|
||||||
app.include_router(
|
app.include_router(
|
||||||
create_document_routes(
|
create_document_routes(
|
||||||
|
|
@ -637,7 +733,7 @@ def create_app(args):
|
||||||
api_key,
|
api_key,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
app.include_router(create_query_routes(rag, api_key, args.top_k))
|
app.include_router(create_query_routes(rag, api_key, args.top_k, token_tracker))
|
||||||
app.include_router(create_graph_routes(rag, api_key))
|
app.include_router(create_graph_routes(rag, api_key))
|
||||||
|
|
||||||
# Add Ollama API routes
|
# Add Ollama API routes
|
||||||
|
|
|
||||||
|
|
@ -243,6 +243,7 @@ class InsertResponse(BaseModel):
|
||||||
status: Status of the operation (success, duplicated, partial_success, failure)
|
status: Status of the operation (success, duplicated, partial_success, failure)
|
||||||
message: Detailed message describing the operation result
|
message: Detailed message describing the operation result
|
||||||
track_id: Tracking ID for monitoring processing status
|
track_id: Tracking ID for monitoring processing status
|
||||||
|
token_usage: Token usage statistics for the insertion (prompt_tokens, completion_tokens, total_tokens, call_count)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
status: Literal["success", "duplicated", "partial_success", "failure"] = Field(
|
status: Literal["success", "duplicated", "partial_success", "failure"] = Field(
|
||||||
|
|
@ -250,6 +251,10 @@ class InsertResponse(BaseModel):
|
||||||
)
|
)
|
||||||
message: str = Field(description="Message describing the operation result")
|
message: str = Field(description="Message describing the operation result")
|
||||||
track_id: str = Field(description="Tracking ID for monitoring processing status")
|
track_id: str = Field(description="Tracking ID for monitoring processing status")
|
||||||
|
token_usage: Optional[Dict[str, int]] = Field(
|
||||||
|
default=None,
|
||||||
|
description="Token usage statistics for the insertion (prompt_tokens, completion_tokens, total_tokens, call_count)",
|
||||||
|
)
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
json_schema_extra = {
|
json_schema_extra = {
|
||||||
|
|
@ -257,10 +262,17 @@ class InsertResponse(BaseModel):
|
||||||
"status": "success",
|
"status": "success",
|
||||||
"message": "File 'document.pdf' uploaded successfully. Processing will continue in background.",
|
"message": "File 'document.pdf' uploaded successfully. Processing will continue in background.",
|
||||||
"track_id": "upload_20250729_170612_abc123",
|
"track_id": "upload_20250729_170612_abc123",
|
||||||
|
"token_usage": {
|
||||||
|
"prompt_tokens": 1250,
|
||||||
|
"completion_tokens": 450,
|
||||||
|
"total_tokens": 1700,
|
||||||
|
"call_count": 3,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ClearDocumentsResponse(BaseModel):
|
class ClearDocumentsResponse(BaseModel):
|
||||||
"""Response model for document clearing operation
|
"""Response model for document clearing operation
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List, Literal, Optional
|
from typing import Any, Dict, List, Literal, Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
||||||
from lightrag.base import QueryParam
|
from lightrag.base import QueryParam
|
||||||
from lightrag.types import MetadataFilter
|
from lightrag.types import MetadataFilter
|
||||||
from lightrag.api.utils_api import get_combined_auth_dependency
|
from lightrag.api.utils_api import get_combined_auth_dependency
|
||||||
|
|
@ -157,6 +157,10 @@ class QueryResponse(BaseModel):
|
||||||
default=None,
|
default=None,
|
||||||
description="Reference list (Disabled when include_references=False, /query/data always includes references.)",
|
description="Reference list (Disabled when include_references=False, /query/data always includes references.)",
|
||||||
)
|
)
|
||||||
|
token_usage: Optional[Dict[str, int]] = Field(
|
||||||
|
default=None,
|
||||||
|
description="Token usage statistics for the query (prompt_tokens, completion_tokens, total_tokens)",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class QueryDataResponse(BaseModel):
|
class QueryDataResponse(BaseModel):
|
||||||
|
|
@ -168,6 +172,10 @@ class QueryDataResponse(BaseModel):
|
||||||
metadata: Dict[str, Any] = Field(
|
metadata: Dict[str, Any] = Field(
|
||||||
description="Query metadata including mode, keywords, and processing information"
|
description="Query metadata including mode, keywords, and processing information"
|
||||||
)
|
)
|
||||||
|
token_usage: Optional[Dict[str, int]] = Field(
|
||||||
|
default=None,
|
||||||
|
description="Token usage statistics for the query (prompt_tokens, completion_tokens, total_tokens)",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class StreamChunkResponse(BaseModel):
|
class StreamChunkResponse(BaseModel):
|
||||||
|
|
@ -183,9 +191,18 @@ class StreamChunkResponse(BaseModel):
|
||||||
error: Optional[str] = Field(
|
error: Optional[str] = Field(
|
||||||
default=None, description="Error message if processing fails"
|
default=None, description="Error message if processing fails"
|
||||||
)
|
)
|
||||||
|
token_usage: Optional[Dict[str, int]] = Field(
|
||||||
|
default=None,
|
||||||
|
description="Token usage statistics for the entire query (only in final chunk)",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
|
def create_query_routes(
|
||||||
|
rag,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
top_k: int = 60,
|
||||||
|
token_tracker: Optional[Any] = None,
|
||||||
|
):
|
||||||
combined_auth = get_combined_auth_dependency(api_key)
|
combined_auth = get_combined_auth_dependency(api_key)
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
|
|
@ -220,7 +237,7 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
|
||||||
},
|
},
|
||||||
"examples": {
|
"examples": {
|
||||||
"with_references": {
|
"with_references": {
|
||||||
"summary": "Response with references",
|
"summary": "Response with references and token usage",
|
||||||
"description": "Example response when include_references=True",
|
"description": "Example response when include_references=True",
|
||||||
"value": {
|
"value": {
|
||||||
"response": "Artificial Intelligence (AI) is a branch of computer science that aims to create intelligent machines capable of performing tasks that typically require human intelligence, such as learning, reasoning, and problem-solving.",
|
"response": "Artificial Intelligence (AI) is a branch of computer science that aims to create intelligent machines capable of performing tasks that typically require human intelligence, such as learning, reasoning, and problem-solving.",
|
||||||
|
|
@ -234,13 +251,25 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
|
||||||
"file_path": "/documents/machine_learning.txt",
|
"file_path": "/documents/machine_learning.txt",
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
|
"token_usage": {
|
||||||
|
"prompt_tokens": 245,
|
||||||
|
"completion_tokens": 87,
|
||||||
|
"total_tokens": 332,
|
||||||
|
"call_count": 1,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"without_references": {
|
"without_references": {
|
||||||
"summary": "Response without references",
|
"summary": "Response without references but with token usage",
|
||||||
"description": "Example response when include_references=False",
|
"description": "Example response when include_references=False",
|
||||||
"value": {
|
"value": {
|
||||||
"response": "Artificial Intelligence (AI) is a branch of computer science that aims to create intelligent machines capable of performing tasks that typically require human intelligence, such as learning, reasoning, and problem-solving."
|
"response": "Artificial Intelligence (AI) is a branch of computer science that aims to create intelligent machines capable of performing tasks that typically require human intelligence, such as learning, reasoning, and problem-solving.",
|
||||||
|
"token_usage": {
|
||||||
|
"prompt_tokens": 245,
|
||||||
|
"completion_tokens": 87,
|
||||||
|
"total_tokens": 332,
|
||||||
|
"call_count": 1,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"different_modes": {
|
"different_modes": {
|
||||||
|
|
@ -358,6 +387,10 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
|
||||||
- 500: Internal processing error (e.g., LLM service unavailable)
|
- 500: Internal processing error (e.g., LLM service unavailable)
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
# Reset token tracker at start of query if available
|
||||||
|
if token_tracker:
|
||||||
|
token_tracker.reset()
|
||||||
|
|
||||||
param = request.to_query_params(
|
param = request.to_query_params(
|
||||||
False
|
False
|
||||||
) # Ensure stream=False for non-streaming endpoint
|
) # Ensure stream=False for non-streaming endpoint
|
||||||
|
|
@ -376,11 +409,22 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
|
||||||
if not response_content:
|
if not response_content:
|
||||||
response_content = "No relevant context found for the query."
|
response_content = "No relevant context found for the query."
|
||||||
|
|
||||||
|
# Get token usage if available
|
||||||
|
token_usage = None
|
||||||
|
if token_tracker:
|
||||||
|
token_usage = token_tracker.get_usage()
|
||||||
|
|
||||||
# Return response with or without references based on request
|
# Return response with or without references based on request
|
||||||
if request.include_references:
|
if request.include_references:
|
||||||
return QueryResponse(response=response_content, references=references)
|
return QueryResponse(
|
||||||
|
response=response_content,
|
||||||
|
references=references,
|
||||||
|
token_usage=token_usage,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return QueryResponse(response=response_content, references=None)
|
return QueryResponse(
|
||||||
|
response=response_content, references=None, token_usage=token_usage
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
trace_exception(e)
|
trace_exception(e)
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
@ -577,6 +621,10 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
|
||||||
Use streaming mode for real-time interfaces and non-streaming for batch processing.
|
Use streaming mode for real-time interfaces and non-streaming for batch processing.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
# Reset token tracker at start of query if available
|
||||||
|
if token_tracker:
|
||||||
|
token_tracker.reset()
|
||||||
|
|
||||||
# Use the stream parameter from the request, defaulting to True if not specified
|
# Use the stream parameter from the request, defaulting to True if not specified
|
||||||
stream_mode = request.stream if request.stream is not None else True
|
stream_mode = request.stream if request.stream is not None else True
|
||||||
param = request.to_query_params(stream_mode)
|
param = request.to_query_params(stream_mode)
|
||||||
|
|
@ -605,6 +653,10 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Streaming error: {str(e)}")
|
logging.error(f"Streaming error: {str(e)}")
|
||||||
yield f"{json.dumps({'error': str(e)})}\n"
|
yield f"{json.dumps({'error': str(e)})}\n"
|
||||||
|
|
||||||
|
# Add final token usage chunk if streaming and token tracker is available
|
||||||
|
if token_tracker and llm_response.get("is_streaming"):
|
||||||
|
yield f"{json.dumps({'token_usage': token_tracker.get_usage()})}\n"
|
||||||
else:
|
else:
|
||||||
# Non-streaming mode: send complete response in one message
|
# Non-streaming mode: send complete response in one message
|
||||||
response_content = llm_response.get("content", "")
|
response_content = llm_response.get("content", "")
|
||||||
|
|
@ -616,6 +668,10 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
|
||||||
if request.include_references:
|
if request.include_references:
|
||||||
complete_response["references"] = references
|
complete_response["references"] = references
|
||||||
|
|
||||||
|
# Add token usage if available
|
||||||
|
if token_tracker:
|
||||||
|
complete_response["token_usage"] = token_tracker.get_usage()
|
||||||
|
|
||||||
yield f"{json.dumps(complete_response)}\n"
|
yield f"{json.dumps(complete_response)}\n"
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
|
|
@ -1022,18 +1078,31 @@ def create_query_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
|
||||||
as structured data analysis typically requires source attribution.
|
as structured data analysis typically requires source attribution.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
# Reset token tracker at start of query if available
|
||||||
|
if token_tracker:
|
||||||
|
token_tracker.reset()
|
||||||
|
|
||||||
param = request.to_query_params(False) # No streaming for data endpoint
|
param = request.to_query_params(False) # No streaming for data endpoint
|
||||||
response = await rag.aquery_data(request.query, param=param)
|
response = await rag.aquery_data(request.query, param=param)
|
||||||
|
|
||||||
|
# Get token usage if available
|
||||||
|
token_usage = None
|
||||||
|
if token_tracker:
|
||||||
|
token_usage = token_tracker.get_usage()
|
||||||
|
|
||||||
# aquery_data returns the new format with status, message, data, and metadata
|
# aquery_data returns the new format with status, message, data, and metadata
|
||||||
if isinstance(response, dict):
|
if isinstance(response, dict):
|
||||||
return QueryDataResponse(**response)
|
response_dict = dict(response)
|
||||||
|
response_dict["token_usage"] = token_usage
|
||||||
|
return QueryDataResponse(**response_dict)
|
||||||
else:
|
else:
|
||||||
# Handle unexpected response format
|
# Handle unexpected response format
|
||||||
return QueryDataResponse(
|
return QueryDataResponse(
|
||||||
status="failure",
|
status="failure",
|
||||||
message="Invalid response type",
|
message="Unexpected response format",
|
||||||
data={},
|
data={},
|
||||||
|
metadata={},
|
||||||
|
token_usage=None,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
trace_exception(e)
|
trace_exception(e)
|
||||||
|
|
|
||||||
|
|
@ -870,7 +870,8 @@ class LightRAG:
|
||||||
ids: str | list[str] | None = None,
|
ids: str | list[str] | None = None,
|
||||||
file_paths: str | list[str] | None = None,
|
file_paths: str | list[str] | None = None,
|
||||||
track_id: str | None = None,
|
track_id: str | None = None,
|
||||||
) -> str:
|
token_tracker: TokenTracker | None = None,
|
||||||
|
) -> tuple[str, dict]:
|
||||||
"""Sync Insert documents with checkpoint support
|
"""Sync Insert documents with checkpoint support
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -884,10 +885,10 @@ class LightRAG:
|
||||||
track_id: tracking ID for monitoring processing status, if not provided, will be generated
|
track_id: tracking ID for monitoring processing status, if not provided, will be generated
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: tracking ID for monitoring processing status
|
tuple[str, dict]: (tracking ID for monitoring processing status, token usage statistics)
|
||||||
"""
|
"""
|
||||||
loop = always_get_an_event_loop()
|
loop = always_get_an_event_loop()
|
||||||
return loop.run_until_complete(
|
result = loop.run_until_complete(
|
||||||
self.ainsert(
|
self.ainsert(
|
||||||
input,
|
input,
|
||||||
split_by_character,
|
split_by_character,
|
||||||
|
|
@ -895,8 +896,10 @@ class LightRAG:
|
||||||
ids,
|
ids,
|
||||||
file_paths,
|
file_paths,
|
||||||
track_id,
|
track_id,
|
||||||
|
token_tracker,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
async def ainsert(
|
async def ainsert(
|
||||||
self,
|
self,
|
||||||
|
|
@ -906,7 +909,8 @@ class LightRAG:
|
||||||
ids: str | list[str] | None = None,
|
ids: str | list[str] | None = None,
|
||||||
file_paths: str | list[str] | None = None,
|
file_paths: str | list[str] | None = None,
|
||||||
track_id: str | None = None,
|
track_id: str | None = None,
|
||||||
) -> str:
|
token_tracker: TokenTracker | None = None,
|
||||||
|
) -> tuple[str, dict]:
|
||||||
"""Async Insert documents with checkpoint support
|
"""Async Insert documents with checkpoint support
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -930,9 +934,15 @@ class LightRAG:
|
||||||
await self.apipeline_process_enqueue_documents(
|
await self.apipeline_process_enqueue_documents(
|
||||||
split_by_character,
|
split_by_character,
|
||||||
split_by_character_only,
|
split_by_character_only,
|
||||||
|
token_tracker,
|
||||||
)
|
)
|
||||||
|
|
||||||
return track_id
|
return track_id, token_tracker.get_usage() if token_tracker else {
|
||||||
|
"prompt_tokens": 0,
|
||||||
|
"completion_tokens": 0,
|
||||||
|
"total_tokens": 0,
|
||||||
|
"call_count": 0,
|
||||||
|
}
|
||||||
|
|
||||||
# TODO: deprecated, use insert instead
|
# TODO: deprecated, use insert instead
|
||||||
def insert_custom_chunks(
|
def insert_custom_chunks(
|
||||||
|
|
@ -1107,7 +1117,9 @@ class LightRAG:
|
||||||
"file_path"
|
"file_path"
|
||||||
], # Store file path in document status
|
], # Store file path in document status
|
||||||
"track_id": track_id, # Store track_id in document status
|
"track_id": track_id, # Store track_id in document status
|
||||||
"metadata": metadata[i] if isinstance(metadata, list) and i < len(metadata) else metadata, # added provided custom metadata per document
|
"metadata": metadata[i]
|
||||||
|
if isinstance(metadata, list) and i < len(metadata)
|
||||||
|
else metadata, # added provided custom metadata per document
|
||||||
}
|
}
|
||||||
for i, (id_, content_data) in enumerate(contents.items())
|
for i, (id_, content_data) in enumerate(contents.items())
|
||||||
}
|
}
|
||||||
|
|
@ -1363,6 +1375,7 @@ class LightRAG:
|
||||||
self,
|
self,
|
||||||
split_by_character: str | None = None,
|
split_by_character: str | None = None,
|
||||||
split_by_character_only: bool = False,
|
split_by_character_only: bool = False,
|
||||||
|
token_tracker: TokenTracker | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Process pending documents by splitting them into chunks, processing
|
Process pending documents by splitting them into chunks, processing
|
||||||
|
|
@ -1484,9 +1497,12 @@ class LightRAG:
|
||||||
pipeline_status: dict,
|
pipeline_status: dict,
|
||||||
pipeline_status_lock: asyncio.Lock,
|
pipeline_status_lock: asyncio.Lock,
|
||||||
semaphore: asyncio.Semaphore,
|
semaphore: asyncio.Semaphore,
|
||||||
|
token_tracker: TokenTracker | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Process single document"""
|
"""Process single document"""
|
||||||
doc_metadata = getattr(status_doc, "metadata", None)
|
doc_metadata = getattr(status_doc, "metadata", None)
|
||||||
|
if doc_metadata is None:
|
||||||
|
doc_metadata = {}
|
||||||
file_extraction_stage_ok = False
|
file_extraction_stage_ok = False
|
||||||
async with semaphore:
|
async with semaphore:
|
||||||
nonlocal processed_count
|
nonlocal processed_count
|
||||||
|
|
@ -1498,7 +1514,6 @@ class LightRAG:
|
||||||
# Get file path from status document
|
# Get file path from status document
|
||||||
file_path = getattr(
|
file_path = getattr(
|
||||||
status_doc, "file_path", "unknown_source"
|
status_doc, "file_path", "unknown_source"
|
||||||
|
|
||||||
)
|
)
|
||||||
|
|
||||||
async with pipeline_status_lock:
|
async with pipeline_status_lock:
|
||||||
|
|
@ -1561,7 +1576,9 @@ class LightRAG:
|
||||||
|
|
||||||
# Process document in two stages
|
# Process document in two stages
|
||||||
# Stage 1: Process text chunks and docs (parallel execution)
|
# Stage 1: Process text chunks and docs (parallel execution)
|
||||||
doc_metadata["processing_start_time"] = processing_start_time
|
doc_metadata["processing_start_time"] = (
|
||||||
|
processing_start_time
|
||||||
|
)
|
||||||
|
|
||||||
doc_status_task = asyncio.create_task(
|
doc_status_task = asyncio.create_task(
|
||||||
self.doc_status.upsert(
|
self.doc_status.upsert(
|
||||||
|
|
@ -1610,6 +1627,7 @@ class LightRAG:
|
||||||
doc_metadata,
|
doc_metadata,
|
||||||
pipeline_status,
|
pipeline_status,
|
||||||
pipeline_status_lock,
|
pipeline_status_lock,
|
||||||
|
token_tracker,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
await entity_relation_task
|
await entity_relation_task
|
||||||
|
|
@ -1643,7 +1661,9 @@ class LightRAG:
|
||||||
processing_end_time = int(time.time())
|
processing_end_time = int(time.time())
|
||||||
|
|
||||||
# Update document status to failed
|
# Update document status to failed
|
||||||
doc_metadata["processing_start_time"] = processing_start_time
|
doc_metadata["processing_start_time"] = (
|
||||||
|
processing_start_time
|
||||||
|
)
|
||||||
doc_metadata["processing_end_time"] = processing_end_time
|
doc_metadata["processing_end_time"] = processing_end_time
|
||||||
await self.doc_status.upsert(
|
await self.doc_status.upsert(
|
||||||
{
|
{
|
||||||
|
|
@ -1691,7 +1711,9 @@ class LightRAG:
|
||||||
doc_metadata["processing_start_time"] = (
|
doc_metadata["processing_start_time"] = (
|
||||||
processing_start_time
|
processing_start_time
|
||||||
)
|
)
|
||||||
doc_metadata["processing_end_time"] = processing_end_time
|
doc_metadata["processing_end_time"] = (
|
||||||
|
processing_end_time
|
||||||
|
)
|
||||||
|
|
||||||
await self.doc_status.upsert(
|
await self.doc_status.upsert(
|
||||||
{
|
{
|
||||||
|
|
@ -1748,7 +1770,9 @@ class LightRAG:
|
||||||
doc_metadata["processing_start_time"] = (
|
doc_metadata["processing_start_time"] = (
|
||||||
processing_start_time
|
processing_start_time
|
||||||
)
|
)
|
||||||
doc_metadata["processing_end_time"] = processing_end_time
|
doc_metadata["processing_end_time"] = (
|
||||||
|
processing_end_time
|
||||||
|
)
|
||||||
|
|
||||||
await self.doc_status.upsert(
|
await self.doc_status.upsert(
|
||||||
{
|
{
|
||||||
|
|
@ -1778,6 +1802,7 @@ class LightRAG:
|
||||||
pipeline_status,
|
pipeline_status,
|
||||||
pipeline_status_lock,
|
pipeline_status_lock,
|
||||||
semaphore,
|
semaphore,
|
||||||
|
token_tracker,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -1827,6 +1852,7 @@ class LightRAG:
|
||||||
metadata: dict | None,
|
metadata: dict | None,
|
||||||
pipeline_status=None,
|
pipeline_status=None,
|
||||||
pipeline_status_lock=None,
|
pipeline_status_lock=None,
|
||||||
|
token_tracker: TokenTracker | None = None,
|
||||||
) -> list:
|
) -> list:
|
||||||
try:
|
try:
|
||||||
chunk_results = await extract_entities(
|
chunk_results = await extract_entities(
|
||||||
|
|
@ -1837,6 +1863,7 @@ class LightRAG:
|
||||||
pipeline_status_lock=pipeline_status_lock,
|
pipeline_status_lock=pipeline_status_lock,
|
||||||
llm_response_cache=self.llm_response_cache,
|
llm_response_cache=self.llm_response_cache,
|
||||||
text_chunks_storage=self.text_chunks,
|
text_chunks_storage=self.text_chunks,
|
||||||
|
token_tracker=token_tracker,
|
||||||
)
|
)
|
||||||
return chunk_results
|
return chunk_results
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -2061,14 +2088,18 @@ class LightRAG:
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
param: QueryParam = QueryParam(),
|
param: QueryParam = QueryParam(),
|
||||||
|
token_tracker: TokenTracker | None = None,
|
||||||
system_prompt: str | None = None,
|
system_prompt: str | None = None,
|
||||||
) -> str | Iterator[str]:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Perform a sync query.
|
User query interface (backward compatibility wrapper).
|
||||||
|
|
||||||
|
Delegates to aquery() for asynchronous execution and returns the result.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query (str): The query to be executed.
|
query (str): The query to be executed.
|
||||||
param (QueryParam): Configuration parameters for query execution.
|
param (QueryParam): Configuration parameters for query execution.
|
||||||
|
token_tracker (TokenTracker | None): Optional token tracker for monitoring usage.
|
||||||
prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].
|
prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|
@ -2076,14 +2107,17 @@ class LightRAG:
|
||||||
"""
|
"""
|
||||||
loop = always_get_an_event_loop()
|
loop = always_get_an_event_loop()
|
||||||
|
|
||||||
return loop.run_until_complete(self.aquery(query, param, system_prompt)) # type: ignore
|
return loop.run_until_complete(
|
||||||
|
self.aquery(query, param, token_tracker, system_prompt)
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
async def aquery(
|
async def aquery(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
param: QueryParam = QueryParam(),
|
param: QueryParam = QueryParam(),
|
||||||
|
token_tracker: TokenTracker | None = None,
|
||||||
system_prompt: str | None = None,
|
system_prompt: str | None = None,
|
||||||
) -> str | AsyncIterator[str]:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Perform a async query (backward compatibility wrapper).
|
Perform a async query (backward compatibility wrapper).
|
||||||
|
|
||||||
|
|
@ -2102,7 +2136,7 @@ class LightRAG:
|
||||||
- Streaming: Returns AsyncIterator[str]
|
- Streaming: Returns AsyncIterator[str]
|
||||||
"""
|
"""
|
||||||
# Call the new aquery_llm function to get complete results
|
# Call the new aquery_llm function to get complete results
|
||||||
result = await self.aquery_llm(query, param, system_prompt)
|
result = await self.aquery_llm(query, param, system_prompt, token_tracker)
|
||||||
|
|
||||||
# Extract and return only the LLM response for backward compatibility
|
# Extract and return only the LLM response for backward compatibility
|
||||||
llm_response = result.get("llm_response", {})
|
llm_response = result.get("llm_response", {})
|
||||||
|
|
@ -2137,6 +2171,7 @@ class LightRAG:
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
param: QueryParam = QueryParam(),
|
param: QueryParam = QueryParam(),
|
||||||
|
token_tracker: TokenTracker | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Asynchronous data retrieval API: returns structured retrieval results without LLM generation.
|
Asynchronous data retrieval API: returns structured retrieval results without LLM generation.
|
||||||
|
|
@ -2330,6 +2365,7 @@ class LightRAG:
|
||||||
query: str,
|
query: str,
|
||||||
param: QueryParam = QueryParam(),
|
param: QueryParam = QueryParam(),
|
||||||
system_prompt: str | None = None,
|
system_prompt: str | None = None,
|
||||||
|
token_tracker: TokenTracker | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Asynchronous complete query API: returns structured retrieval results with LLM generation.
|
Asynchronous complete query API: returns structured retrieval results with LLM generation.
|
||||||
|
|
@ -2364,6 +2400,7 @@ class LightRAG:
|
||||||
hashing_kv=self.llm_response_cache,
|
hashing_kv=self.llm_response_cache,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
chunks_vdb=self.chunks_vdb,
|
chunks_vdb=self.chunks_vdb,
|
||||||
|
token_tracker=token_tracker,
|
||||||
)
|
)
|
||||||
elif param.mode == "naive":
|
elif param.mode == "naive":
|
||||||
query_result = await naive_query(
|
query_result = await naive_query(
|
||||||
|
|
@ -2373,6 +2410,7 @@ class LightRAG:
|
||||||
global_config,
|
global_config,
|
||||||
hashing_kv=self.llm_response_cache,
|
hashing_kv=self.llm_response_cache,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
|
token_tracker=token_tracker,
|
||||||
)
|
)
|
||||||
elif param.mode == "bypass":
|
elif param.mode == "bypass":
|
||||||
# Bypass mode: directly use LLM without knowledge retrieval
|
# Bypass mode: directly use LLM without knowledge retrieval
|
||||||
|
|
|
||||||
|
|
@ -46,6 +46,7 @@ async def azure_openai_complete_if_cache(
|
||||||
base_url: str | None = None,
|
base_url: str | None = None,
|
||||||
api_key: str | None = None,
|
api_key: str | None = None,
|
||||||
api_version: str | None = None,
|
api_version: str | None = None,
|
||||||
|
token_tracker: Any | None = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if enable_cot:
|
if enable_cot:
|
||||||
|
|
@ -94,28 +95,73 @@ async def azure_openai_complete_if_cache(
|
||||||
)
|
)
|
||||||
|
|
||||||
if hasattr(response, "__aiter__"):
|
if hasattr(response, "__aiter__"):
|
||||||
|
final_chunk_usage = None
|
||||||
|
accumulated_response = ""
|
||||||
|
|
||||||
async def inner():
|
async def inner():
|
||||||
async for chunk in response:
|
nonlocal final_chunk_usage, accumulated_response
|
||||||
if len(chunk.choices) == 0:
|
try:
|
||||||
continue
|
async for chunk in response:
|
||||||
content = chunk.choices[0].delta.content
|
if len(chunk.choices) == 0:
|
||||||
if content is None:
|
continue
|
||||||
continue
|
content = chunk.choices[0].delta.content
|
||||||
if r"\u" in content:
|
if content is None:
|
||||||
content = safe_unicode_decode(content.encode("utf-8"))
|
continue
|
||||||
yield content
|
accumulated_response += content
|
||||||
|
if r"\u" in content:
|
||||||
|
content = safe_unicode_decode(content.encode("utf-8"))
|
||||||
|
yield content
|
||||||
|
|
||||||
|
# Check for usage in the last chunk
|
||||||
|
if hasattr(chunk, "usage") and chunk.usage is not None:
|
||||||
|
final_chunk_usage = chunk.usage
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in Azure OpenAI stream response: {str(e)}")
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
# After streaming is complete, track token usage
|
||||||
|
if token_tracker and final_chunk_usage:
|
||||||
|
# Use actual usage from the API
|
||||||
|
token_counts = {
|
||||||
|
"prompt_tokens": getattr(final_chunk_usage, "prompt_tokens", 0),
|
||||||
|
"completion_tokens": getattr(
|
||||||
|
final_chunk_usage, "completion_tokens", 0
|
||||||
|
),
|
||||||
|
"total_tokens": getattr(final_chunk_usage, "total_tokens", 0),
|
||||||
|
}
|
||||||
|
token_tracker.add_usage(token_counts)
|
||||||
|
logger.debug(f"Azure OpenAI streaming token usage: {token_counts}")
|
||||||
|
elif token_tracker:
|
||||||
|
logger.debug(
|
||||||
|
"No usage information available in Azure OpenAI streaming response"
|
||||||
|
)
|
||||||
|
|
||||||
return inner()
|
return inner()
|
||||||
else:
|
else:
|
||||||
content = response.choices[0].message.content
|
content = response.choices[0].message.content
|
||||||
if r"\u" in content:
|
if r"\u" in content:
|
||||||
content = safe_unicode_decode(content.encode("utf-8"))
|
content = safe_unicode_decode(content.encode("utf-8"))
|
||||||
|
|
||||||
|
# Track token usage for non-streaming response
|
||||||
|
if token_tracker and hasattr(response, "usage"):
|
||||||
|
token_counts = {
|
||||||
|
"prompt_tokens": getattr(response.usage, "prompt_tokens", 0),
|
||||||
|
"completion_tokens": getattr(response.usage, "completion_tokens", 0),
|
||||||
|
"total_tokens": getattr(response.usage, "total_tokens", 0),
|
||||||
|
}
|
||||||
|
token_tracker.add_usage(token_counts)
|
||||||
|
logger.debug(f"Azure OpenAI non-streaming token usage: {token_counts}")
|
||||||
|
|
||||||
return content
|
return content
|
||||||
|
|
||||||
|
|
||||||
async def azure_openai_complete(
|
async def azure_openai_complete(
|
||||||
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
|
prompt,
|
||||||
|
system_prompt=None,
|
||||||
|
history_messages=[],
|
||||||
|
keyword_extraction=False,
|
||||||
|
token_tracker=None,
|
||||||
|
**kwargs,
|
||||||
) -> str:
|
) -> str:
|
||||||
kwargs.pop("keyword_extraction", None)
|
kwargs.pop("keyword_extraction", None)
|
||||||
result = await azure_openai_complete_if_cache(
|
result = await azure_openai_complete_if_cache(
|
||||||
|
|
@ -123,6 +169,7 @@ async def azure_openai_complete(
|
||||||
prompt,
|
prompt,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
history_messages=history_messages,
|
history_messages=history_messages,
|
||||||
|
token_tracker=token_tracker,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
@ -142,6 +189,7 @@ async def azure_openai_embed(
|
||||||
base_url: str | None = None,
|
base_url: str | None = None,
|
||||||
api_key: str | None = None,
|
api_key: str | None = None,
|
||||||
api_version: str | None = None,
|
api_version: str | None = None,
|
||||||
|
token_tracker: Any | None = None,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
deployment = (
|
deployment = (
|
||||||
os.getenv("AZURE_EMBEDDING_DEPLOYMENT")
|
os.getenv("AZURE_EMBEDDING_DEPLOYMENT")
|
||||||
|
|
@ -174,4 +222,14 @@ async def azure_openai_embed(
|
||||||
response = await openai_async_client.embeddings.create(
|
response = await openai_async_client.embeddings.create(
|
||||||
model=model, input=texts, encoding_format="float"
|
model=model, input=texts, encoding_format="float"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Track token usage for embeddings if token tracker is provided
|
||||||
|
if token_tracker and hasattr(response, "usage"):
|
||||||
|
token_counts = {
|
||||||
|
"prompt_tokens": getattr(response.usage, "prompt_tokens", 0),
|
||||||
|
"total_tokens": getattr(response.usage, "total_tokens", 0),
|
||||||
|
}
|
||||||
|
token_tracker.add_usage(token_counts)
|
||||||
|
logger.debug(f"Azure OpenAI embedding token usage: {token_counts}")
|
||||||
|
|
||||||
return np.array([dp.embedding for dp in response.data])
|
return np.array([dp.embedding for dp in response.data])
|
||||||
|
|
|
||||||
|
|
@ -48,6 +48,7 @@ async def bedrock_complete_if_cache(
|
||||||
aws_access_key_id=None,
|
aws_access_key_id=None,
|
||||||
aws_secret_access_key=None,
|
aws_secret_access_key=None,
|
||||||
aws_session_token=None,
|
aws_session_token=None,
|
||||||
|
token_tracker=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Union[str, AsyncIterator[str]]:
|
) -> Union[str, AsyncIterator[str]]:
|
||||||
if enable_cot:
|
if enable_cot:
|
||||||
|
|
@ -155,6 +156,18 @@ async def bedrock_complete_if_cache(
|
||||||
yield text
|
yield text
|
||||||
# Handle other event types that might indicate stream end
|
# Handle other event types that might indicate stream end
|
||||||
elif "messageStop" in event:
|
elif "messageStop" in event:
|
||||||
|
# Track token usage for streaming if token tracker is provided
|
||||||
|
if token_tracker and "usage" in event:
|
||||||
|
usage = event["usage"]
|
||||||
|
token_counts = {
|
||||||
|
"prompt_tokens": usage.get("inputTokens", 0),
|
||||||
|
"completion_tokens": usage.get("outputTokens", 0),
|
||||||
|
"total_tokens": usage.get("totalTokens", 0),
|
||||||
|
}
|
||||||
|
token_tracker.add_usage(token_counts)
|
||||||
|
logging.debug(
|
||||||
|
f"Bedrock streaming token usage: {token_counts}"
|
||||||
|
)
|
||||||
break
|
break
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -228,6 +241,17 @@ async def bedrock_complete_if_cache(
|
||||||
if not content or content.strip() == "":
|
if not content or content.strip() == "":
|
||||||
raise BedrockError("Received empty content from Bedrock API")
|
raise BedrockError("Received empty content from Bedrock API")
|
||||||
|
|
||||||
|
# Track token usage for non-streaming if token tracker is provided
|
||||||
|
if token_tracker and "usage" in response:
|
||||||
|
usage = response["usage"]
|
||||||
|
token_counts = {
|
||||||
|
"prompt_tokens": usage.get("inputTokens", 0),
|
||||||
|
"completion_tokens": usage.get("outputTokens", 0),
|
||||||
|
"total_tokens": usage.get("totalTokens", 0),
|
||||||
|
}
|
||||||
|
token_tracker.add_usage(token_counts)
|
||||||
|
logging.debug(f"Bedrock non-streaming token usage: {token_counts}")
|
||||||
|
|
||||||
return content
|
return content
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -239,7 +263,12 @@ async def bedrock_complete_if_cache(
|
||||||
|
|
||||||
# Generic Bedrock completion function
|
# Generic Bedrock completion function
|
||||||
async def bedrock_complete(
|
async def bedrock_complete(
|
||||||
prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
|
prompt,
|
||||||
|
system_prompt=None,
|
||||||
|
history_messages=[],
|
||||||
|
keyword_extraction=False,
|
||||||
|
token_tracker=None,
|
||||||
|
**kwargs,
|
||||||
) -> Union[str, AsyncIterator[str]]:
|
) -> Union[str, AsyncIterator[str]]:
|
||||||
kwargs.pop("keyword_extraction", None)
|
kwargs.pop("keyword_extraction", None)
|
||||||
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
|
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
|
||||||
|
|
@ -248,6 +277,7 @@ async def bedrock_complete(
|
||||||
prompt,
|
prompt,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
history_messages=history_messages,
|
history_messages=history_messages,
|
||||||
|
token_tracker=token_tracker,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
@ -265,6 +295,7 @@ async def bedrock_embed(
|
||||||
aws_access_key_id=None,
|
aws_access_key_id=None,
|
||||||
aws_secret_access_key=None,
|
aws_secret_access_key=None,
|
||||||
aws_session_token=None,
|
aws_session_token=None,
|
||||||
|
token_tracker=None,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
# Respect existing env; only set if a non-empty value is available
|
# Respect existing env; only set if a non-empty value is available
|
||||||
access_key = os.environ.get("AWS_ACCESS_KEY_ID") or aws_access_key_id
|
access_key = os.environ.get("AWS_ACCESS_KEY_ID") or aws_access_key_id
|
||||||
|
|
|
||||||
|
|
@ -108,6 +108,7 @@ async def lollms_model_complete(
|
||||||
history_messages=[],
|
history_messages=[],
|
||||||
enable_cot: bool = False,
|
enable_cot: bool = False,
|
||||||
keyword_extraction=False,
|
keyword_extraction=False,
|
||||||
|
token_tracker=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Union[str, AsyncIterator[str]]:
|
) -> Union[str, AsyncIterator[str]]:
|
||||||
"""Complete function for lollms model generation."""
|
"""Complete function for lollms model generation."""
|
||||||
|
|
@ -135,7 +136,11 @@ async def lollms_model_complete(
|
||||||
|
|
||||||
|
|
||||||
async def lollms_embed(
|
async def lollms_embed(
|
||||||
texts: List[str], embed_model=None, base_url="http://localhost:9600", **kwargs
|
texts: List[str],
|
||||||
|
embed_model=None,
|
||||||
|
base_url="http://localhost:9600",
|
||||||
|
token_tracker=None,
|
||||||
|
**kwargs,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
Generate embeddings for a list of texts using lollms server.
|
Generate embeddings for a list of texts using lollms server.
|
||||||
|
|
@ -144,6 +149,7 @@ async def lollms_embed(
|
||||||
texts: List of strings to embed
|
texts: List of strings to embed
|
||||||
embed_model: Model name (not used directly as lollms uses configured vectorizer)
|
embed_model: Model name (not used directly as lollms uses configured vectorizer)
|
||||||
base_url: URL of the lollms server
|
base_url: URL of the lollms server
|
||||||
|
token_tracker: Optional token usage tracker for monitoring API usage
|
||||||
**kwargs: Additional arguments passed to the request
|
**kwargs: Additional arguments passed to the request
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|
|
||||||
|
|
@ -39,6 +39,7 @@ async def _ollama_model_if_cache(
|
||||||
system_prompt=None,
|
system_prompt=None,
|
||||||
history_messages=[],
|
history_messages=[],
|
||||||
enable_cot: bool = False,
|
enable_cot: bool = False,
|
||||||
|
token_tracker=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Union[str, AsyncIterator[str]]:
|
) -> Union[str, AsyncIterator[str]]:
|
||||||
if enable_cot:
|
if enable_cot:
|
||||||
|
|
@ -74,13 +75,47 @@ async def _ollama_model_if_cache(
|
||||||
"""cannot cache stream response and process reasoning"""
|
"""cannot cache stream response and process reasoning"""
|
||||||
|
|
||||||
async def inner():
|
async def inner():
|
||||||
|
accumulated_response = ""
|
||||||
try:
|
try:
|
||||||
async for chunk in response:
|
async for chunk in response:
|
||||||
yield chunk["message"]["content"]
|
chunk_content = chunk["message"]["content"]
|
||||||
|
accumulated_response += chunk_content
|
||||||
|
yield chunk_content
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in stream response: {str(e)}")
|
logger.error(f"Error in stream response: {str(e)}")
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
|
# Track token usage for streaming if token tracker is provided
|
||||||
|
if token_tracker:
|
||||||
|
# Estimate prompt tokens: roughly 4 characters per token for English text
|
||||||
|
prompt_text = ""
|
||||||
|
if system_prompt:
|
||||||
|
prompt_text += system_prompt + " "
|
||||||
|
prompt_text += (
|
||||||
|
" ".join(
|
||||||
|
[msg.get("content", "") for msg in history_messages]
|
||||||
|
)
|
||||||
|
+ " "
|
||||||
|
)
|
||||||
|
prompt_text += prompt
|
||||||
|
prompt_tokens = len(prompt_text) // 4 + (
|
||||||
|
1 if len(prompt_text) % 4 else 0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Estimate completion tokens from accumulated response
|
||||||
|
completion_tokens = len(accumulated_response) // 4 + (
|
||||||
|
1 if len(accumulated_response) % 4 else 0
|
||||||
|
)
|
||||||
|
total_tokens = prompt_tokens + completion_tokens
|
||||||
|
|
||||||
|
token_tracker.add_usage(
|
||||||
|
{
|
||||||
|
"prompt_tokens": prompt_tokens,
|
||||||
|
"completion_tokens": completion_tokens,
|
||||||
|
"total_tokens": total_tokens,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await ollama_client._client.aclose()
|
await ollama_client._client.aclose()
|
||||||
logger.debug("Successfully closed Ollama client for streaming")
|
logger.debug("Successfully closed Ollama client for streaming")
|
||||||
|
|
@ -91,6 +126,35 @@ async def _ollama_model_if_cache(
|
||||||
else:
|
else:
|
||||||
model_response = response["message"]["content"]
|
model_response = response["message"]["content"]
|
||||||
|
|
||||||
|
# Track token usage if token tracker is provided
|
||||||
|
# Note: Ollama doesn't provide token usage in chat responses, so we estimate
|
||||||
|
if token_tracker:
|
||||||
|
# Estimate prompt tokens: roughly 4 characters per token for English text
|
||||||
|
prompt_text = ""
|
||||||
|
if system_prompt:
|
||||||
|
prompt_text += system_prompt + " "
|
||||||
|
prompt_text += (
|
||||||
|
" ".join([msg.get("content", "") for msg in history_messages]) + " "
|
||||||
|
)
|
||||||
|
prompt_text += prompt
|
||||||
|
prompt_tokens = len(prompt_text) // 4 + (
|
||||||
|
1 if len(prompt_text) % 4 else 0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Estimate completion tokens from response
|
||||||
|
completion_tokens = len(model_response) // 4 + (
|
||||||
|
1 if len(model_response) % 4 else 0
|
||||||
|
)
|
||||||
|
total_tokens = prompt_tokens + completion_tokens
|
||||||
|
|
||||||
|
token_tracker.add_usage(
|
||||||
|
{
|
||||||
|
"prompt_tokens": prompt_tokens,
|
||||||
|
"completion_tokens": completion_tokens,
|
||||||
|
"total_tokens": total_tokens,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
If the model also wraps its thoughts in a specific tag,
|
If the model also wraps its thoughts in a specific tag,
|
||||||
this information is not needed for the final
|
this information is not needed for the final
|
||||||
|
|
@ -126,6 +190,7 @@ async def ollama_model_complete(
|
||||||
history_messages=[],
|
history_messages=[],
|
||||||
enable_cot: bool = False,
|
enable_cot: bool = False,
|
||||||
keyword_extraction=False,
|
keyword_extraction=False,
|
||||||
|
token_tracker=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Union[str, AsyncIterator[str]]:
|
) -> Union[str, AsyncIterator[str]]:
|
||||||
keyword_extraction = kwargs.pop("keyword_extraction", None)
|
keyword_extraction = kwargs.pop("keyword_extraction", None)
|
||||||
|
|
@ -138,11 +203,14 @@ async def ollama_model_complete(
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
history_messages=history_messages,
|
history_messages=history_messages,
|
||||||
enable_cot=enable_cot,
|
enable_cot=enable_cot,
|
||||||
|
token_tracker=token_tracker,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
|
async def ollama_embed(
|
||||||
|
texts: list[str], embed_model, token_tracker=None, **kwargs
|
||||||
|
) -> np.ndarray:
|
||||||
api_key = kwargs.pop("api_key", None)
|
api_key = kwargs.pop("api_key", None)
|
||||||
headers = {
|
headers = {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
|
|
@ -160,6 +228,21 @@ async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
|
||||||
data = await ollama_client.embed(
|
data = await ollama_client.embed(
|
||||||
model=embed_model, input=texts, options=options
|
model=embed_model, input=texts, options=options
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Track token usage if token tracker is provided
|
||||||
|
# Note: Ollama doesn't provide token usage in embedding responses, so we estimate
|
||||||
|
if token_tracker:
|
||||||
|
# Estimate tokens: roughly 4 characters per token for English text
|
||||||
|
total_chars = sum(len(text) for text in texts)
|
||||||
|
estimated_tokens = total_chars // 4 + (1 if total_chars % 4 else 0)
|
||||||
|
token_tracker.add_usage(
|
||||||
|
{
|
||||||
|
"prompt_tokens": estimated_tokens,
|
||||||
|
"completion_tokens": 0,
|
||||||
|
"total_tokens": estimated_tokens,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return np.array(data["embeddings"])
|
return np.array(data["embeddings"])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in ollama_embed: {str(e)}")
|
logger.error(f"Error in ollama_embed: {str(e)}")
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,7 @@ from .utils import (
|
||||||
logger,
|
logger,
|
||||||
compute_mdhash_id,
|
compute_mdhash_id,
|
||||||
Tokenizer,
|
Tokenizer,
|
||||||
|
TokenTracker,
|
||||||
is_float_regex,
|
is_float_regex,
|
||||||
sanitize_and_normalize_extracted_text,
|
sanitize_and_normalize_extracted_text,
|
||||||
pack_user_ass_to_openai_messages,
|
pack_user_ass_to_openai_messages,
|
||||||
|
|
@ -126,6 +127,7 @@ async def _handle_entity_relation_summary(
|
||||||
seperator: str,
|
seperator: str,
|
||||||
global_config: dict,
|
global_config: dict,
|
||||||
llm_response_cache: BaseKVStorage | None = None,
|
llm_response_cache: BaseKVStorage | None = None,
|
||||||
|
token_tracker: TokenTracker | None = None,
|
||||||
) -> tuple[str, bool]:
|
) -> tuple[str, bool]:
|
||||||
"""Handle entity relation description summary using map-reduce approach.
|
"""Handle entity relation description summary using map-reduce approach.
|
||||||
|
|
||||||
|
|
@ -188,6 +190,7 @@ async def _handle_entity_relation_summary(
|
||||||
current_list,
|
current_list,
|
||||||
global_config,
|
global_config,
|
||||||
llm_response_cache,
|
llm_response_cache,
|
||||||
|
token_tracker,
|
||||||
)
|
)
|
||||||
return final_summary, True # LLM was used for final summarization
|
return final_summary, True # LLM was used for final summarization
|
||||||
|
|
||||||
|
|
@ -243,6 +246,7 @@ async def _handle_entity_relation_summary(
|
||||||
chunk,
|
chunk,
|
||||||
global_config,
|
global_config,
|
||||||
llm_response_cache,
|
llm_response_cache,
|
||||||
|
token_tracker,
|
||||||
)
|
)
|
||||||
new_summaries.append(summary)
|
new_summaries.append(summary)
|
||||||
llm_was_used = True # Mark that LLM was used in reduce phase
|
llm_was_used = True # Mark that LLM was used in reduce phase
|
||||||
|
|
@ -257,6 +261,7 @@ async def _summarize_descriptions(
|
||||||
description_list: list[str],
|
description_list: list[str],
|
||||||
global_config: dict,
|
global_config: dict,
|
||||||
llm_response_cache: BaseKVStorage | None = None,
|
llm_response_cache: BaseKVStorage | None = None,
|
||||||
|
token_tracker: TokenTracker | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Helper function to summarize a list of descriptions using LLM.
|
"""Helper function to summarize a list of descriptions using LLM.
|
||||||
|
|
||||||
|
|
@ -312,9 +317,10 @@ async def _summarize_descriptions(
|
||||||
# Use LLM function with cache (higher priority for summary generation)
|
# Use LLM function with cache (higher priority for summary generation)
|
||||||
summary, _ = await use_llm_func_with_cache(
|
summary, _ = await use_llm_func_with_cache(
|
||||||
use_prompt,
|
use_prompt,
|
||||||
use_llm_func,
|
use_llm_func=use_llm_func,
|
||||||
llm_response_cache=llm_response_cache,
|
hashing_kv=llm_response_cache,
|
||||||
cache_type="summary",
|
cache_type="summary",
|
||||||
|
token_tracker=token_tracker,
|
||||||
)
|
)
|
||||||
return summary
|
return summary
|
||||||
|
|
||||||
|
|
@ -405,7 +411,7 @@ async def _handle_single_relationship_extraction(
|
||||||
): # treat "relationship" and "relation" interchangeable
|
): # treat "relationship" and "relation" interchangeable
|
||||||
if len(record_attributes) > 1 and "relation" in record_attributes[0]:
|
if len(record_attributes) > 1 and "relation" in record_attributes[0]:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"{chunk_key}: LLM output format error; found {len(record_attributes)}/5 fields on REALTION `{record_attributes[1]}`~`{record_attributes[2] if len(record_attributes) >2 else 'N/A'}`"
|
f"{chunk_key}: LLM output format error; found {len(record_attributes)}/5 fields on REALTION `{record_attributes[1]}`~`{record_attributes[2] if len(record_attributes) > 2 else 'N/A'}`"
|
||||||
)
|
)
|
||||||
logger.debug(record_attributes)
|
logger.debug(record_attributes)
|
||||||
return None
|
return None
|
||||||
|
|
@ -463,7 +469,6 @@ async def _handle_single_relationship_extraction(
|
||||||
file_path=file_path,
|
file_path=file_path,
|
||||||
timestamp=timestamp,
|
timestamp=timestamp,
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
|
|
||||||
)
|
)
|
||||||
|
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
|
|
@ -2037,6 +2042,7 @@ async def extract_entities(
|
||||||
pipeline_status_lock=None,
|
pipeline_status_lock=None,
|
||||||
llm_response_cache: BaseKVStorage | None = None,
|
llm_response_cache: BaseKVStorage | None = None,
|
||||||
text_chunks_storage: BaseKVStorage | None = None,
|
text_chunks_storage: BaseKVStorage | None = None,
|
||||||
|
token_tracker: TokenTracker | None = None,
|
||||||
) -> list:
|
) -> list:
|
||||||
use_llm_func: callable = global_config["llm_model_func"]
|
use_llm_func: callable = global_config["llm_model_func"]
|
||||||
entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
|
entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
|
||||||
|
|
@ -2150,12 +2156,13 @@ async def extract_entities(
|
||||||
|
|
||||||
final_result, timestamp = await use_llm_func_with_cache(
|
final_result, timestamp = await use_llm_func_with_cache(
|
||||||
entity_extraction_user_prompt,
|
entity_extraction_user_prompt,
|
||||||
use_llm_func,
|
use_llm_func=use_llm_func,
|
||||||
system_prompt=entity_extraction_system_prompt,
|
system_prompt=entity_extraction_system_prompt,
|
||||||
llm_response_cache=llm_response_cache,
|
hashing_kv=llm_response_cache,
|
||||||
cache_type="extract",
|
cache_type="extract",
|
||||||
chunk_id=chunk_key,
|
chunk_id=chunk_key,
|
||||||
cache_keys_collector=cache_keys_collector,
|
cache_keys_collector=cache_keys_collector,
|
||||||
|
token_tracker=token_tracker,
|
||||||
)
|
)
|
||||||
|
|
||||||
history = pack_user_ass_to_openai_messages(
|
history = pack_user_ass_to_openai_messages(
|
||||||
|
|
@ -2177,16 +2184,16 @@ async def extract_entities(
|
||||||
if entity_extract_max_gleaning > 0:
|
if entity_extract_max_gleaning > 0:
|
||||||
glean_result, timestamp = await use_llm_func_with_cache(
|
glean_result, timestamp = await use_llm_func_with_cache(
|
||||||
entity_continue_extraction_user_prompt,
|
entity_continue_extraction_user_prompt,
|
||||||
use_llm_func,
|
use_llm_func=use_llm_func,
|
||||||
system_prompt=entity_extraction_system_prompt,
|
system_prompt=entity_extraction_system_prompt,
|
||||||
llm_response_cache=llm_response_cache,
|
llm_response_cache=llm_response_cache,
|
||||||
history_messages=history,
|
history_messages=history,
|
||||||
cache_type="extract",
|
cache_type="extract",
|
||||||
chunk_id=chunk_key,
|
chunk_id=chunk_key,
|
||||||
cache_keys_collector=cache_keys_collector,
|
cache_keys_collector=cache_keys_collector,
|
||||||
|
token_tracker=token_tracker,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Process gleaning result separately with file path and metadata
|
# Process gleaning result separately with file path and metadata
|
||||||
glean_nodes, glean_edges = await _process_extraction_result(
|
glean_nodes, glean_edges = await _process_extraction_result(
|
||||||
glean_result,
|
glean_result,
|
||||||
|
|
@ -2300,7 +2307,7 @@ async def extract_entities(
|
||||||
await asyncio.wait(pending)
|
await asyncio.wait(pending)
|
||||||
|
|
||||||
# Add progress prefix to the exception message
|
# Add progress prefix to the exception message
|
||||||
progress_prefix = f"C[{processed_chunks+1}/{total_chunks}]"
|
progress_prefix = f"C[{processed_chunks + 1}/{total_chunks}]"
|
||||||
|
|
||||||
# Re-raise the original exception with a prefix
|
# Re-raise the original exception with a prefix
|
||||||
prefixed_exception = create_prefixed_exception(first_exception, progress_prefix)
|
prefixed_exception = create_prefixed_exception(first_exception, progress_prefix)
|
||||||
|
|
@ -2324,6 +2331,7 @@ async def kg_query(
|
||||||
system_prompt: str | None = None,
|
system_prompt: str | None = None,
|
||||||
chunks_vdb: BaseVectorStorage = None,
|
chunks_vdb: BaseVectorStorage = None,
|
||||||
return_raw_data: Literal[True] = False,
|
return_raw_data: Literal[True] = False,
|
||||||
|
token_tracker: TokenTracker | None = None,
|
||||||
) -> dict[str, Any]: ...
|
) -> dict[str, Any]: ...
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -2341,6 +2349,7 @@ async def kg_query(
|
||||||
chunks_vdb: BaseVectorStorage = None,
|
chunks_vdb: BaseVectorStorage = None,
|
||||||
metadata_filters: list | None = None,
|
metadata_filters: list | None = None,
|
||||||
return_raw_data: Literal[False] = False,
|
return_raw_data: Literal[False] = False,
|
||||||
|
token_tracker: TokenTracker | None = None,
|
||||||
) -> str | AsyncIterator[str]: ...
|
) -> str | AsyncIterator[str]: ...
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -2355,6 +2364,7 @@ async def kg_query(
|
||||||
hashing_kv: BaseKVStorage | None = None,
|
hashing_kv: BaseKVStorage | None = None,
|
||||||
system_prompt: str | None = None,
|
system_prompt: str | None = None,
|
||||||
chunks_vdb: BaseVectorStorage = None,
|
chunks_vdb: BaseVectorStorage = None,
|
||||||
|
token_tracker: TokenTracker | None = None,
|
||||||
) -> QueryResult:
|
) -> QueryResult:
|
||||||
"""
|
"""
|
||||||
Execute knowledge graph query and return unified QueryResult object.
|
Execute knowledge graph query and return unified QueryResult object.
|
||||||
|
|
@ -2422,7 +2432,7 @@ async def kg_query(
|
||||||
return QueryResult(content=cached_response)
|
return QueryResult(content=cached_response)
|
||||||
|
|
||||||
hl_keywords, ll_keywords = await get_keywords_from_query(
|
hl_keywords, ll_keywords = await get_keywords_from_query(
|
||||||
query, query_param, global_config, hashing_kv
|
query, query_param, global_config, hashing_kv, token_tracker
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f"High-level keywords: {hl_keywords}")
|
logger.debug(f"High-level keywords: {hl_keywords}")
|
||||||
|
|
@ -2526,6 +2536,7 @@ async def kg_query(
|
||||||
history_messages=query_param.conversation_history,
|
history_messages=query_param.conversation_history,
|
||||||
enable_cot=True,
|
enable_cot=True,
|
||||||
stream=query_param.stream,
|
stream=query_param.stream,
|
||||||
|
token_tracker=token_tracker,
|
||||||
)
|
)
|
||||||
|
|
||||||
if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"):
|
if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"):
|
||||||
|
|
@ -2583,6 +2594,7 @@ async def get_keywords_from_query(
|
||||||
query_param: QueryParam,
|
query_param: QueryParam,
|
||||||
global_config: dict[str, str],
|
global_config: dict[str, str],
|
||||||
hashing_kv: BaseKVStorage | None = None,
|
hashing_kv: BaseKVStorage | None = None,
|
||||||
|
token_tracker: TokenTracker | None = None,
|
||||||
) -> tuple[list[str], list[str]]:
|
) -> tuple[list[str], list[str]]:
|
||||||
"""
|
"""
|
||||||
Retrieves high-level and low-level keywords for RAG operations.
|
Retrieves high-level and low-level keywords for RAG operations.
|
||||||
|
|
@ -2605,7 +2617,7 @@ async def get_keywords_from_query(
|
||||||
|
|
||||||
# Extract keywords using extract_keywords_only function which already supports conversation history
|
# Extract keywords using extract_keywords_only function which already supports conversation history
|
||||||
hl_keywords, ll_keywords = await extract_keywords_only(
|
hl_keywords, ll_keywords = await extract_keywords_only(
|
||||||
query, query_param, global_config, hashing_kv
|
query, query_param, global_config, hashing_kv, token_tracker
|
||||||
)
|
)
|
||||||
return hl_keywords, ll_keywords
|
return hl_keywords, ll_keywords
|
||||||
|
|
||||||
|
|
@ -2615,6 +2627,7 @@ async def extract_keywords_only(
|
||||||
param: QueryParam,
|
param: QueryParam,
|
||||||
global_config: dict[str, str],
|
global_config: dict[str, str],
|
||||||
hashing_kv: BaseKVStorage | None = None,
|
hashing_kv: BaseKVStorage | None = None,
|
||||||
|
token_tracker: TokenTracker | None = None,
|
||||||
) -> tuple[list[str], list[str]]:
|
) -> tuple[list[str], list[str]]:
|
||||||
"""
|
"""
|
||||||
Extract high-level and low-level keywords from the given 'text' using the LLM.
|
Extract high-level and low-level keywords from the given 'text' using the LLM.
|
||||||
|
|
@ -2668,7 +2681,9 @@ async def extract_keywords_only(
|
||||||
# Apply higher priority (5) to query relation LLM function
|
# Apply higher priority (5) to query relation LLM function
|
||||||
use_model_func = partial(use_model_func, _priority=5)
|
use_model_func = partial(use_model_func, _priority=5)
|
||||||
|
|
||||||
result = await use_model_func(kw_prompt, keyword_extraction=True)
|
result = await use_model_func(
|
||||||
|
kw_prompt, keyword_extraction=True, token_tracker=token_tracker
|
||||||
|
)
|
||||||
|
|
||||||
# 5. Parse out JSON from the LLM response
|
# 5. Parse out JSON from the LLM response
|
||||||
result = remove_think_tags(result)
|
result = remove_think_tags(result)
|
||||||
|
|
@ -2746,7 +2761,10 @@ async def _get_vector_context(
|
||||||
cosine_threshold = chunks_vdb.cosine_better_than_threshold
|
cosine_threshold = chunks_vdb.cosine_better_than_threshold
|
||||||
|
|
||||||
results = await chunks_vdb.query(
|
results = await chunks_vdb.query(
|
||||||
query, top_k=search_top_k, query_embedding=query_embedding, metadata_filter=query_param.metadata_filter
|
query,
|
||||||
|
top_k=search_top_k,
|
||||||
|
query_embedding=query_embedding,
|
||||||
|
metadata_filter=query_param.metadata_filter,
|
||||||
)
|
)
|
||||||
if not results:
|
if not results:
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|
@ -2763,7 +2781,7 @@ async def _get_vector_context(
|
||||||
"file_path": result.get("file_path", "unknown_source"),
|
"file_path": result.get("file_path", "unknown_source"),
|
||||||
"source_type": "vector", # Mark the source type
|
"source_type": "vector", # Mark the source type
|
||||||
"chunk_id": result.get("id"), # Add chunk_id for deduplication
|
"chunk_id": result.get("id"), # Add chunk_id for deduplication
|
||||||
"metadata": result.get("metadata")
|
"metadata": result.get("metadata"),
|
||||||
}
|
}
|
||||||
valid_chunks.append(chunk_with_metadata)
|
valid_chunks.append(chunk_with_metadata)
|
||||||
|
|
||||||
|
|
@ -3529,8 +3547,9 @@ async def _get_node_data(
|
||||||
f"Query nodes: {query} (top_k:{query_param.top_k}, cosine:{entities_vdb.cosine_better_than_threshold})"
|
f"Query nodes: {query} (top_k:{query_param.top_k}, cosine:{entities_vdb.cosine_better_than_threshold})"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
results = await entities_vdb.query(
|
||||||
results = await entities_vdb.query(query, top_k=query_param.top_k, metadata_filter=query_param.metadata_filter)
|
query, top_k=query_param.top_k, metadata_filter=query_param.metadata_filter
|
||||||
|
)
|
||||||
|
|
||||||
if not len(results):
|
if not len(results):
|
||||||
return [], []
|
return [], []
|
||||||
|
|
@ -3538,7 +3557,6 @@ async def _get_node_data(
|
||||||
# Extract all entity IDs from your results list
|
# Extract all entity IDs from your results list
|
||||||
node_ids = [r["entity_name"] for r in results]
|
node_ids = [r["entity_name"] for r in results]
|
||||||
|
|
||||||
|
|
||||||
# Extract all entity IDs from your results list
|
# Extract all entity IDs from your results list
|
||||||
node_ids = [r["entity_name"] for r in results]
|
node_ids = [r["entity_name"] for r in results]
|
||||||
|
|
||||||
|
|
@ -3810,7 +3828,9 @@ async def _get_edge_data(
|
||||||
f"Query edges: {keywords} (top_k:{query_param.top_k}, cosine:{relationships_vdb.cosine_better_than_threshold})"
|
f"Query edges: {keywords} (top_k:{query_param.top_k}, cosine:{relationships_vdb.cosine_better_than_threshold})"
|
||||||
)
|
)
|
||||||
|
|
||||||
results = await relationships_vdb.query(keywords, top_k=query_param.top_k, metadata_filter=query_param.metadata_filter)
|
results = await relationships_vdb.query(
|
||||||
|
keywords, top_k=query_param.top_k, metadata_filter=query_param.metadata_filter
|
||||||
|
)
|
||||||
|
|
||||||
if not len(results):
|
if not len(results):
|
||||||
return [], []
|
return [], []
|
||||||
|
|
@ -4104,6 +4124,7 @@ async def naive_query(
|
||||||
hashing_kv: BaseKVStorage | None = None,
|
hashing_kv: BaseKVStorage | None = None,
|
||||||
system_prompt: str | None = None,
|
system_prompt: str | None = None,
|
||||||
return_raw_data: Literal[True] = True,
|
return_raw_data: Literal[True] = True,
|
||||||
|
token_tracker: TokenTracker | None = None,
|
||||||
) -> dict[str, Any]: ...
|
) -> dict[str, Any]: ...
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -4116,6 +4137,7 @@ async def naive_query(
|
||||||
hashing_kv: BaseKVStorage | None = None,
|
hashing_kv: BaseKVStorage | None = None,
|
||||||
system_prompt: str | None = None,
|
system_prompt: str | None = None,
|
||||||
return_raw_data: Literal[False] = False,
|
return_raw_data: Literal[False] = False,
|
||||||
|
token_tracker: TokenTracker | None = None,
|
||||||
) -> str | AsyncIterator[str]: ...
|
) -> str | AsyncIterator[str]: ...
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -4126,6 +4148,7 @@ async def naive_query(
|
||||||
global_config: dict[str, str],
|
global_config: dict[str, str],
|
||||||
hashing_kv: BaseKVStorage | None = None,
|
hashing_kv: BaseKVStorage | None = None,
|
||||||
system_prompt: str | None = None,
|
system_prompt: str | None = None,
|
||||||
|
token_tracker: TokenTracker | None = None,
|
||||||
) -> QueryResult:
|
) -> QueryResult:
|
||||||
"""
|
"""
|
||||||
Execute naive query and return unified QueryResult object.
|
Execute naive query and return unified QueryResult object.
|
||||||
|
|
@ -4321,6 +4344,7 @@ async def naive_query(
|
||||||
history_messages=query_param.conversation_history,
|
history_messages=query_param.conversation_history,
|
||||||
enable_cot=True,
|
enable_cot=True,
|
||||||
stream=query_param.stream,
|
stream=query_param.stream,
|
||||||
|
token_tracker=token_tracker,
|
||||||
)
|
)
|
||||||
|
|
||||||
if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"):
|
if hashing_kv and hashing_kv.global_config.get("enable_llm_cache"):
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ import weakref
|
||||||
import asyncio
|
import asyncio
|
||||||
import html
|
import html
|
||||||
import csv
|
import csv
|
||||||
|
import contextvars
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import logging.handlers
|
import logging.handlers
|
||||||
|
|
@ -507,6 +508,7 @@ def priority_limit_async_func_call(
|
||||||
task_id,
|
task_id,
|
||||||
args,
|
args,
|
||||||
kwargs,
|
kwargs,
|
||||||
|
ctx,
|
||||||
) = await asyncio.wait_for(queue.get(), timeout=1.0)
|
) = await asyncio.wait_for(queue.get(), timeout=1.0)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
continue
|
continue
|
||||||
|
|
@ -536,11 +538,15 @@ def priority_limit_async_func_call(
|
||||||
try:
|
try:
|
||||||
# Execute function with timeout protection
|
# Execute function with timeout protection
|
||||||
if max_execution_timeout is not None:
|
if max_execution_timeout is not None:
|
||||||
|
# Run the function in the captured context
|
||||||
|
task = ctx.run(lambda: asyncio.create_task(func(*args, **kwargs)))
|
||||||
result = await asyncio.wait_for(
|
result = await asyncio.wait_for(
|
||||||
func(*args, **kwargs), timeout=max_execution_timeout
|
task, timeout=max_execution_timeout
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
result = await func(*args, **kwargs)
|
# Run the function in the captured context
|
||||||
|
task = ctx.run(lambda: asyncio.create_task(func(*args, **kwargs)))
|
||||||
|
result = await task
|
||||||
|
|
||||||
# Set result if future is still valid
|
# Set result if future is still valid
|
||||||
if not task_state.future.done():
|
if not task_state.future.done():
|
||||||
|
|
@ -791,6 +797,9 @@ def priority_limit_async_func_call(
|
||||||
future=future, start_time=asyncio.get_event_loop().time()
|
future=future, start_time=asyncio.get_event_loop().time()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Capture current context
|
||||||
|
ctx = contextvars.copy_context()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Register task state
|
# Register task state
|
||||||
async with task_states_lock:
|
async with task_states_lock:
|
||||||
|
|
@ -809,13 +818,13 @@ def priority_limit_async_func_call(
|
||||||
if _queue_timeout is not None:
|
if _queue_timeout is not None:
|
||||||
await asyncio.wait_for(
|
await asyncio.wait_for(
|
||||||
queue.put(
|
queue.put(
|
||||||
(_priority, current_count, task_id, args, kwargs)
|
(_priority, current_count, task_id, args, kwargs, ctx)
|
||||||
),
|
),
|
||||||
timeout=_queue_timeout,
|
timeout=_queue_timeout,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
await queue.put(
|
await queue.put(
|
||||||
(_priority, current_count, task_id, args, kwargs)
|
(_priority, current_count, task_id, args, kwargs, ctx)
|
||||||
)
|
)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
raise QueueFullError(
|
raise QueueFullError(
|
||||||
|
|
@ -1472,8 +1481,7 @@ async def aexport_data(
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unsupported file format: {file_format}. "
|
f"Unsupported file format: {file_format}. Choose from: csv, excel, md, txt"
|
||||||
f"Choose from: csv, excel, md, txt"
|
|
||||||
)
|
)
|
||||||
if file_format is not None:
|
if file_format is not None:
|
||||||
print(f"Data exported to: {output_path} with format: {file_format}")
|
print(f"Data exported to: {output_path} with format: {file_format}")
|
||||||
|
|
@ -1601,6 +1609,8 @@ async def use_llm_func_with_cache(
|
||||||
cache_type: str = "extract",
|
cache_type: str = "extract",
|
||||||
chunk_id: str | None = None,
|
chunk_id: str | None = None,
|
||||||
cache_keys_collector: list = None,
|
cache_keys_collector: list = None,
|
||||||
|
hashing_kv: "BaseKVStorage | None" = None,
|
||||||
|
token_tracker=None,
|
||||||
) -> tuple[str, int]:
|
) -> tuple[str, int]:
|
||||||
"""Call LLM function with cache support and text sanitization
|
"""Call LLM function with cache support and text sanitization
|
||||||
|
|
||||||
|
|
@ -1685,7 +1695,10 @@ async def use_llm_func_with_cache(
|
||||||
kwargs["max_tokens"] = max_tokens
|
kwargs["max_tokens"] = max_tokens
|
||||||
|
|
||||||
res: str = await use_llm_func(
|
res: str = await use_llm_func(
|
||||||
safe_user_prompt, system_prompt=safe_system_prompt, **kwargs
|
safe_user_prompt,
|
||||||
|
system_prompt=safe_system_prompt,
|
||||||
|
token_tracker=token_tracker,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
res = remove_think_tags(res)
|
res = remove_think_tags(res)
|
||||||
|
|
@ -1720,7 +1733,10 @@ async def use_llm_func_with_cache(
|
||||||
|
|
||||||
try:
|
try:
|
||||||
res = await use_llm_func(
|
res = await use_llm_func(
|
||||||
safe_user_prompt, system_prompt=safe_system_prompt, **kwargs
|
safe_user_prompt,
|
||||||
|
system_prompt=safe_system_prompt,
|
||||||
|
token_tracker=token_tracker,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Add [LLM func] prefix to error message
|
# Add [LLM func] prefix to error message
|
||||||
|
|
@ -2216,52 +2232,74 @@ async def pick_by_vector_similarity(
|
||||||
return all_chunk_ids[:num_of_chunks]
|
return all_chunk_ids[:num_of_chunks]
|
||||||
|
|
||||||
|
|
||||||
|
from contextvars import ContextVar
|
||||||
|
|
||||||
|
|
||||||
class TokenTracker:
|
class TokenTracker:
|
||||||
"""Track token usage for LLM calls."""
|
"""Track token usage for LLM calls using ContextVars for concurrency support."""
|
||||||
|
|
||||||
|
_usage_var: ContextVar[dict] = ContextVar("token_usage", default=None)
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.reset()
|
# No instance state needed as we use ContextVar
|
||||||
|
pass
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
self.reset()
|
self.reset()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
print(self)
|
# Optional: Log usage on exit if needed
|
||||||
|
pass
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
self.prompt_tokens = 0
|
"""Initialize/Reset token usage for the current context."""
|
||||||
self.completion_tokens = 0
|
self._usage_var.set(
|
||||||
self.total_tokens = 0
|
{
|
||||||
self.call_count = 0
|
"prompt_tokens": 0,
|
||||||
|
"completion_tokens": 0,
|
||||||
|
"total_tokens": 0,
|
||||||
|
"call_count": 0,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
def add_usage(self, token_counts):
|
def _get_current_usage(self) -> dict:
|
||||||
|
"""Get the usage dict for the current context, initializing if necessary."""
|
||||||
|
usage = self._usage_var.get()
|
||||||
|
if usage is None:
|
||||||
|
usage = {
|
||||||
|
"prompt_tokens": 0,
|
||||||
|
"completion_tokens": 0,
|
||||||
|
"total_tokens": 0,
|
||||||
|
"call_count": 0,
|
||||||
|
}
|
||||||
|
self._usage_var.set(usage)
|
||||||
|
return usage
|
||||||
|
|
||||||
|
def add_usage(self, token_counts: dict):
|
||||||
"""Add token usage from one LLM call.
|
"""Add token usage from one LLM call.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
token_counts: A dictionary containing prompt_tokens, completion_tokens, total_tokens
|
token_counts: A dictionary containing prompt_tokens, completion_tokens, total_tokens
|
||||||
"""
|
"""
|
||||||
self.prompt_tokens += token_counts.get("prompt_tokens", 0)
|
usage = self._get_current_usage()
|
||||||
self.completion_tokens += token_counts.get("completion_tokens", 0)
|
|
||||||
|
usage["prompt_tokens"] += token_counts.get("prompt_tokens", 0)
|
||||||
|
usage["completion_tokens"] += token_counts.get("completion_tokens", 0)
|
||||||
|
|
||||||
# If total_tokens is provided, use it directly; otherwise calculate the sum
|
# If total_tokens is provided, use it directly; otherwise calculate the sum
|
||||||
if "total_tokens" in token_counts:
|
if "total_tokens" in token_counts:
|
||||||
self.total_tokens += token_counts["total_tokens"]
|
usage["total_tokens"] += token_counts["total_tokens"]
|
||||||
else:
|
else:
|
||||||
self.total_tokens += token_counts.get(
|
usage["total_tokens"] += token_counts.get(
|
||||||
"prompt_tokens", 0
|
"prompt_tokens", 0
|
||||||
) + token_counts.get("completion_tokens", 0)
|
) + token_counts.get("completion_tokens", 0)
|
||||||
|
|
||||||
self.call_count += 1
|
usage["call_count"] += 1
|
||||||
|
|
||||||
def get_usage(self):
|
def get_usage(self):
|
||||||
"""Get current usage statistics."""
|
"""Get current usage statistics."""
|
||||||
return {
|
return self._get_current_usage().copy()
|
||||||
"prompt_tokens": self.prompt_tokens,
|
|
||||||
"completion_tokens": self.completion_tokens,
|
|
||||||
"total_tokens": self.total_tokens,
|
|
||||||
"call_count": self.call_count,
|
|
||||||
}
|
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
usage = self.get_usage()
|
usage = self.get_usage()
|
||||||
|
|
@ -2273,6 +2311,26 @@ class TokenTracker:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def estimate_embedding_tokens(texts: list[str], tokenizer: Tokenizer) -> int:
|
||||||
|
"""Estimate tokens for embedding operations based on text length.
|
||||||
|
|
||||||
|
Most embedding APIs don't return token counts, so we estimate based on
|
||||||
|
the tokenizer encoding. This provides a reasonable approximation for tracking.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: List of text strings to be embedded
|
||||||
|
tokenizer: Tokenizer instance for encoding
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Estimated total token count for all texts
|
||||||
|
"""
|
||||||
|
total = 0
|
||||||
|
for text in texts:
|
||||||
|
if text: # Skip empty strings
|
||||||
|
total += len(tokenizer.encode(text))
|
||||||
|
return total
|
||||||
|
|
||||||
|
|
||||||
async def apply_rerank_if_enabled(
|
async def apply_rerank_if_enabled(
|
||||||
query: str,
|
query: str,
|
||||||
retrieved_docs: list[dict],
|
retrieved_docs: list[dict],
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue