Compare commits
1 commit
main
...
pensar-aut
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c227fad2fb |
1 changed files with 37 additions and 3 deletions
|
|
@ -181,12 +181,46 @@ def retrieved_edges_to_string(search_results):
|
|||
|
||||
|
||||
def load_class(model_file, model_name):
|
||||
model_file = os.path.abspath(model_file)
|
||||
spec = importlib.util.spec_from_file_location("graph_model", model_file)
|
||||
"""
|
||||
Securely loads a class from the trusted models directory.
|
||||
|
||||
Only allows loading Python files under the 'cognee/modules/models/' directory.
|
||||
The model name must be a valid Python identifier, refer to a class, and be defined in the module.
|
||||
"""
|
||||
# Define the allowed directory (patched: only allow loading models from this trusted location)
|
||||
BASE_MODEL_DIR = os.path.abspath(
|
||||
os.path.join(os.path.dirname(__file__), "cognee", "modules", "models")
|
||||
)
|
||||
|
||||
abs_model_file = os.path.abspath(model_file)
|
||||
|
||||
# Check that the file is within the allowed directory (prevent path traversal and arbitrary locations)
|
||||
if not abs_model_file.startswith(BASE_MODEL_DIR + os.sep):
|
||||
raise ValueError("Model file must be located within the trusted models directory.")
|
||||
|
||||
# File must end with .py
|
||||
if not abs_model_file.endswith(".py"):
|
||||
raise ValueError("Model file must be a Python (.py) file.")
|
||||
|
||||
# File must exist
|
||||
if not os.path.isfile(abs_model_file):
|
||||
raise ValueError("Model file does not exist.")
|
||||
|
||||
# Validate class name: must be identifier
|
||||
if not model_name or not model_name.isidentifier():
|
||||
raise ValueError("Model class name must be a valid Python identifier.")
|
||||
|
||||
# Load module as before, from absolute, trusted, validated path
|
||||
spec = importlib.util.spec_from_file_location("graph_model", abs_model_file)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
# Ensure the model_name exists and is a class in the module
|
||||
if not hasattr(module, model_name):
|
||||
raise ValueError(f"Model class '{model_name}' not found in file.")
|
||||
model_class = getattr(module, model_name)
|
||||
if not isinstance(model_class, type):
|
||||
raise ValueError(f"Attribute '{model_name}' is not a class.")
|
||||
|
||||
return model_class
|
||||
|
||||
|
|
@ -218,4 +252,4 @@ if __name__ == "__main__":
|
|||
asyncio.run(main())
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing Cognee MCP server: {str(e)}")
|
||||
raise
|
||||
raise
|
||||
Loading…
Add table
Reference in a new issue