From d054ec5d0016a3a76f8aa98f2bc670ac14fb05ff Mon Sep 17 00:00:00 2001 From: Thibo Rosemplatt <38402230+thiborose@users.noreply.github.com> Date: Sat, 23 Aug 2025 20:16:11 +0200 Subject: [PATCH] Added entity_types as a user defined variable (via .env) --- env.example | 2 ++ lightrag/api/config.py | 2 ++ lightrag/api/lightrag_server.py | 4 ++-- lightrag/api/utils_api.py | 2 ++ lightrag/constants.py | 1 + lightrag/lightrag.py | 4 +++- lightrag/operate.py | 3 ++- lightrag/prompt.py | 2 -- 8 files changed, 14 insertions(+), 6 deletions(-) diff --git a/env.example b/env.example index 41c77ede..58117a6e 100644 --- a/env.example +++ b/env.example @@ -131,6 +131,8 @@ ENABLE_LLM_CACHE_FOR_EXTRACT=true # SUMMARY_MAX_TOKENS=30000 ### Maximum number of entity extraction attempts for ambiguous content # MAX_GLEANING=1 +### Customize the entities that the LLM will attempt to recognize +# ENTITY_TYPES=["person", "organization", "location", "event", "concept"] ############################### ### Concurrency Configuration diff --git a/lightrag/api/config.py b/lightrag/api/config.py index a5e352dc..891491e2 100644 --- a/lightrag/api/config.py +++ b/lightrag/api/config.py @@ -36,6 +36,7 @@ from lightrag.constants import ( DEFAULT_OLLAMA_MODEL_NAME, DEFAULT_OLLAMA_MODEL_TAG, DEFAULT_RERANK_BINDING, + DEFAULT_ENTITY_TYPES ) # use the .env that is inside the current folder @@ -333,6 +334,7 @@ def parse_args() -> argparse.Namespace: # Add environment variables that were previously read directly args.cors_origins = get_env_value("CORS_ORIGINS", "*") args.summary_language = get_env_value("SUMMARY_LANGUAGE", DEFAULT_SUMMARY_LANGUAGE) + args.entity_types = get_env_value("ENTITY_TYPES", DEFAULT_ENTITY_TYPES) args.whitelist_paths = get_env_value("WHITELIST_PATHS", "/health,/api/*") # For JWT Auth diff --git a/lightrag/api/lightrag_server.py b/lightrag/api/lightrag_server.py index 328e7953..26ed5d0b 100644 --- a/lightrag/api/lightrag_server.py +++ b/lightrag/api/lightrag_server.py @@ -497,7 +497,7 @@ def create_app(args): rerank_model_func=rerank_model_func, max_parallel_insert=args.max_parallel_insert, max_graph_nodes=args.max_graph_nodes, - addon_params={"language": args.summary_language}, + addon_params={"language": args.summary_language, "entity_types": args.entity_types}, ollama_server_infos=ollama_server_infos, ) else: # azure_openai @@ -523,7 +523,7 @@ def create_app(args): rerank_model_func=rerank_model_func, max_parallel_insert=args.max_parallel_insert, max_graph_nodes=args.max_graph_nodes, - addon_params={"language": args.summary_language}, + addon_params={"language": args.summary_language, "entity_types": args.entity_types}, ollama_server_infos=ollama_server_infos, ) diff --git a/lightrag/api/utils_api.py b/lightrag/api/utils_api.py index fc05716c..eca3799f 100644 --- a/lightrag/api/utils_api.py +++ b/lightrag/api/utils_api.py @@ -264,6 +264,8 @@ def display_splash_screen(args: argparse.Namespace) -> None: ASCIIColors.magenta("\n⚙️ RAG Configuration:") ASCIIColors.white(" ├─ Summary Language: ", end="") ASCIIColors.yellow(f"{args.summary_language}") + ASCIIColors.white(" ├─ Entity Types: ", end="") + ASCIIColors.yellow(f"{args.entity_types}") ASCIIColors.white(" ├─ Max Parallel Insert: ", end="") ASCIIColors.yellow(f"{args.max_parallel_insert}") ASCIIColors.white(" ├─ Chunk Size: ", end="") diff --git a/lightrag/constants.py b/lightrag/constants.py index 9445872e..a1b2ba46 100644 --- a/lightrag/constants.py +++ b/lightrag/constants.py @@ -15,6 +15,7 @@ DEFAULT_SUMMARY_LANGUAGE = "English" # Default language for summaries DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE = 4 DEFAULT_MAX_GLEANING = 1 DEFAULT_SUMMARY_MAX_TOKENS = 30000 # Default maximum token size +DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event", "category"] # Separator for graph fields GRAPH_FIELD_SEP = "" diff --git a/lightrag/lightrag.py b/lightrag/lightrag.py index 721181d5..f8b6f7fc 100644 --- a/lightrag/lightrag.py +++ b/lightrag/lightrag.py @@ -37,6 +37,7 @@ from lightrag.constants import ( DEFAULT_MAX_ASYNC, DEFAULT_MAX_PARALLEL_INSERT, DEFAULT_MAX_GRAPH_NODES, + DEFAULT_ENTITY_TYPES ) from lightrag.utils import get_env_value @@ -333,7 +334,8 @@ class LightRAG: addon_params: dict[str, Any] = field( default_factory=lambda: { - "language": get_env_value("SUMMARY_LANGUAGE", "English", str) + "language": get_env_value("SUMMARY_LANGUAGE", "English", str), + "entity_types": get_env_value("ENTITY_TYPES", DEFAULT_ENTITY_TYPES, list), } ) diff --git a/lightrag/operate.py b/lightrag/operate.py index 1e1ccbb6..8ae51f41 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -47,6 +47,7 @@ from .constants import ( DEFAULT_MAX_TOTAL_TOKENS, DEFAULT_RELATED_CHUNK_NUMBER, DEFAULT_KG_CHUNK_PICK_METHOD, + DEFAULT_ENTITY_TYPES ) from .kg.shared_storage import get_storage_keyed_lock import time @@ -1487,7 +1488,7 @@ async def extract_entities( "language", PROMPTS["DEFAULT_LANGUAGE"] ) entity_types = global_config["addon_params"].get( - "entity_types", PROMPTS["DEFAULT_ENTITY_TYPES"] + "entity_types", DEFAULT_ENTITY_TYPES ) example_number = global_config["addon_params"].get("example_number", None) if example_number and example_number < len(PROMPTS["entity_extraction_examples"]): diff --git a/lightrag/prompt.py b/lightrag/prompt.py index 32666bb5..b315f445 100644 --- a/lightrag/prompt.py +++ b/lightrag/prompt.py @@ -9,8 +9,6 @@ PROMPTS["DEFAULT_TUPLE_DELIMITER"] = "<|>" PROMPTS["DEFAULT_RECORD_DELIMITER"] = "##" PROMPTS["DEFAULT_COMPLETION_DELIMITER"] = "<|COMPLETE|>" -PROMPTS["DEFAULT_ENTITY_TYPES"] = ["organization", "person", "geo", "event", "category"] - PROMPTS["DEFAULT_USER_PROMPT"] = "n/a" PROMPTS["entity_extraction"] = """---Goal---