Compare commits
1 commit
main
...
pensar-aut
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b9dd2bd916 |
1 changed files with 40 additions and 5 deletions
|
|
@ -1,6 +1,34 @@
|
|||
import os
|
||||
from cognee.infrastructure.llm.get_llm_client import get_llm_client
|
||||
from cognee.infrastructure.llm.prompts import read_query_prompt, render_prompt
|
||||
|
||||
# Define the directory where prompt templates are allowed to reside.
|
||||
PROMPTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompts')
|
||||
# Define the set of allowed prompt filenames. Extend as needed.
|
||||
ALLOWED_PROMPT_FILENAMES = {
|
||||
"summarize_search_results.txt",
|
||||
# Add other allowed prompt template files here.
|
||||
}
|
||||
|
||||
def validate_prompt_path(prompt_path: str) -> str:
|
||||
"""
|
||||
Validates the prompt path to prevent path traversal and local file inclusion.
|
||||
Only allows files within PROMPTS_DIR and with an allowed filename.
|
||||
Returns the cleaned absolute path to the prompt file if valid, raises ValueError otherwise.
|
||||
"""
|
||||
# Only allow filenames (no directory component)
|
||||
filename = os.path.basename(prompt_path)
|
||||
|
||||
# Check for allowed filenames
|
||||
if filename not in ALLOWED_PROMPT_FILENAMES:
|
||||
raise ValueError(f"Invalid prompt filename: {filename}")
|
||||
|
||||
# Construct absolute path to file in prompts directory
|
||||
abs_path = os.path.abspath(os.path.join(PROMPTS_DIR, filename))
|
||||
# Ensure the path is within the prompts directory
|
||||
if not abs_path.startswith(os.path.abspath(PROMPTS_DIR) + os.sep):
|
||||
raise ValueError("Attempted path traversal in prompt path.")
|
||||
return abs_path
|
||||
|
||||
async def generate_completion(
|
||||
query: str,
|
||||
|
|
@ -10,8 +38,13 @@ async def generate_completion(
|
|||
) -> str:
|
||||
"""Generates a completion using LLM with given context and prompts."""
|
||||
args = {"question": query, "context": context}
|
||||
user_prompt = render_prompt(user_prompt_path, args)
|
||||
system_prompt = read_query_prompt(system_prompt_path)
|
||||
|
||||
# Validate prompt paths
|
||||
user_prompt_file = validate_prompt_path(user_prompt_path)
|
||||
system_prompt_file = validate_prompt_path(system_prompt_path)
|
||||
|
||||
user_prompt = render_prompt(user_prompt_file, args)
|
||||
system_prompt = read_query_prompt(system_prompt_file)
|
||||
|
||||
llm_client = get_llm_client()
|
||||
return await llm_client.acreate_structured_output(
|
||||
|
|
@ -20,17 +53,19 @@ async def generate_completion(
|
|||
response_model=str,
|
||||
)
|
||||
|
||||
|
||||
async def summarize_text(
|
||||
text: str,
|
||||
prompt_path: str = "summarize_search_results.txt",
|
||||
) -> str:
|
||||
"""Summarizes text using LLM with the specified prompt."""
|
||||
system_prompt = read_query_prompt(prompt_path)
|
||||
# Validate prompt path
|
||||
prompt_file = validate_prompt_path(prompt_path)
|
||||
|
||||
system_prompt = read_query_prompt(prompt_file)
|
||||
llm_client = get_llm_client()
|
||||
|
||||
return await llm_client.acreate_structured_output(
|
||||
text_input=text,
|
||||
system_prompt=system_prompt,
|
||||
response_model=str,
|
||||
)
|
||||
)
|
||||
Loading…
Add table
Reference in a new issue