Added entity_types as a user defined variable (via .env)
This commit is contained in:
parent
4e79550725
commit
d054ec5d00
8 changed files with 14 additions and 6 deletions
|
|
@ -131,6 +131,8 @@ ENABLE_LLM_CACHE_FOR_EXTRACT=true
|
|||
# SUMMARY_MAX_TOKENS=30000
|
||||
### Maximum number of entity extraction attempts for ambiguous content
|
||||
# MAX_GLEANING=1
|
||||
### Customize the entities that the LLM will attempt to recognize
|
||||
# ENTITY_TYPES=["person", "organization", "location", "event", "concept"]
|
||||
|
||||
###############################
|
||||
### Concurrency Configuration
|
||||
|
|
|
|||
|
|
@ -36,6 +36,7 @@ from lightrag.constants import (
|
|||
DEFAULT_OLLAMA_MODEL_NAME,
|
||||
DEFAULT_OLLAMA_MODEL_TAG,
|
||||
DEFAULT_RERANK_BINDING,
|
||||
DEFAULT_ENTITY_TYPES
|
||||
)
|
||||
|
||||
# use the .env that is inside the current folder
|
||||
|
|
@ -333,6 +334,7 @@ def parse_args() -> argparse.Namespace:
|
|||
# Add environment variables that were previously read directly
|
||||
args.cors_origins = get_env_value("CORS_ORIGINS", "*")
|
||||
args.summary_language = get_env_value("SUMMARY_LANGUAGE", DEFAULT_SUMMARY_LANGUAGE)
|
||||
args.entity_types = get_env_value("ENTITY_TYPES", DEFAULT_ENTITY_TYPES)
|
||||
args.whitelist_paths = get_env_value("WHITELIST_PATHS", "/health,/api/*")
|
||||
|
||||
# For JWT Auth
|
||||
|
|
|
|||
|
|
@ -497,7 +497,7 @@ def create_app(args):
|
|||
rerank_model_func=rerank_model_func,
|
||||
max_parallel_insert=args.max_parallel_insert,
|
||||
max_graph_nodes=args.max_graph_nodes,
|
||||
addon_params={"language": args.summary_language},
|
||||
addon_params={"language": args.summary_language, "entity_types": args.entity_types},
|
||||
ollama_server_infos=ollama_server_infos,
|
||||
)
|
||||
else: # azure_openai
|
||||
|
|
@ -523,7 +523,7 @@ def create_app(args):
|
|||
rerank_model_func=rerank_model_func,
|
||||
max_parallel_insert=args.max_parallel_insert,
|
||||
max_graph_nodes=args.max_graph_nodes,
|
||||
addon_params={"language": args.summary_language},
|
||||
addon_params={"language": args.summary_language, "entity_types": args.entity_types},
|
||||
ollama_server_infos=ollama_server_infos,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -264,6 +264,8 @@ def display_splash_screen(args: argparse.Namespace) -> None:
|
|||
ASCIIColors.magenta("\n⚙️ RAG Configuration:")
|
||||
ASCIIColors.white(" ├─ Summary Language: ", end="")
|
||||
ASCIIColors.yellow(f"{args.summary_language}")
|
||||
ASCIIColors.white(" ├─ Entity Types: ", end="")
|
||||
ASCIIColors.yellow(f"{args.entity_types}")
|
||||
ASCIIColors.white(" ├─ Max Parallel Insert: ", end="")
|
||||
ASCIIColors.yellow(f"{args.max_parallel_insert}")
|
||||
ASCIIColors.white(" ├─ Chunk Size: ", end="")
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ DEFAULT_SUMMARY_LANGUAGE = "English" # Default language for summaries
|
|||
DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE = 4
|
||||
DEFAULT_MAX_GLEANING = 1
|
||||
DEFAULT_SUMMARY_MAX_TOKENS = 30000 # Default maximum token size
|
||||
DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event", "category"]
|
||||
|
||||
# Separator for graph fields
|
||||
GRAPH_FIELD_SEP = "<SEP>"
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ from lightrag.constants import (
|
|||
DEFAULT_MAX_ASYNC,
|
||||
DEFAULT_MAX_PARALLEL_INSERT,
|
||||
DEFAULT_MAX_GRAPH_NODES,
|
||||
DEFAULT_ENTITY_TYPES
|
||||
)
|
||||
from lightrag.utils import get_env_value
|
||||
|
||||
|
|
@ -333,7 +334,8 @@ class LightRAG:
|
|||
|
||||
addon_params: dict[str, Any] = field(
|
||||
default_factory=lambda: {
|
||||
"language": get_env_value("SUMMARY_LANGUAGE", "English", str)
|
||||
"language": get_env_value("SUMMARY_LANGUAGE", "English", str),
|
||||
"entity_types": get_env_value("ENTITY_TYPES", DEFAULT_ENTITY_TYPES, list),
|
||||
}
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -47,6 +47,7 @@ from .constants import (
|
|||
DEFAULT_MAX_TOTAL_TOKENS,
|
||||
DEFAULT_RELATED_CHUNK_NUMBER,
|
||||
DEFAULT_KG_CHUNK_PICK_METHOD,
|
||||
DEFAULT_ENTITY_TYPES
|
||||
)
|
||||
from .kg.shared_storage import get_storage_keyed_lock
|
||||
import time
|
||||
|
|
@ -1487,7 +1488,7 @@ async def extract_entities(
|
|||
"language", PROMPTS["DEFAULT_LANGUAGE"]
|
||||
)
|
||||
entity_types = global_config["addon_params"].get(
|
||||
"entity_types", PROMPTS["DEFAULT_ENTITY_TYPES"]
|
||||
"entity_types", DEFAULT_ENTITY_TYPES
|
||||
)
|
||||
example_number = global_config["addon_params"].get("example_number", None)
|
||||
if example_number and example_number < len(PROMPTS["entity_extraction_examples"]):
|
||||
|
|
|
|||
|
|
@ -9,8 +9,6 @@ PROMPTS["DEFAULT_TUPLE_DELIMITER"] = "<|>"
|
|||
PROMPTS["DEFAULT_RECORD_DELIMITER"] = "##"
|
||||
PROMPTS["DEFAULT_COMPLETION_DELIMITER"] = "<|COMPLETE|>"
|
||||
|
||||
PROMPTS["DEFAULT_ENTITY_TYPES"] = ["organization", "person", "geo", "event", "category"]
|
||||
|
||||
PROMPTS["DEFAULT_USER_PROMPT"] = "n/a"
|
||||
|
||||
PROMPTS["entity_extraction"] = """---Goal---
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue