Compare commits
1 commit
main
...
pensar-aut
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c227fad2fb |
1 changed files with 37 additions and 3 deletions
|
|
@ -181,12 +181,46 @@ def retrieved_edges_to_string(search_results):
|
||||||
|
|
||||||
|
|
||||||
def load_class(model_file, model_name):
|
def load_class(model_file, model_name):
|
||||||
model_file = os.path.abspath(model_file)
|
"""
|
||||||
spec = importlib.util.spec_from_file_location("graph_model", model_file)
|
Securely loads a class from the trusted models directory.
|
||||||
|
|
||||||
|
Only allows loading Python files under the 'cognee/modules/models/' directory.
|
||||||
|
The model name must be a valid Python identifier, refer to a class, and be defined in the module.
|
||||||
|
"""
|
||||||
|
# Define the allowed directory (patched: only allow loading models from this trusted location)
|
||||||
|
BASE_MODEL_DIR = os.path.abspath(
|
||||||
|
os.path.join(os.path.dirname(__file__), "cognee", "modules", "models")
|
||||||
|
)
|
||||||
|
|
||||||
|
abs_model_file = os.path.abspath(model_file)
|
||||||
|
|
||||||
|
# Check that the file is within the allowed directory (prevent path traversal and arbitrary locations)
|
||||||
|
if not abs_model_file.startswith(BASE_MODEL_DIR + os.sep):
|
||||||
|
raise ValueError("Model file must be located within the trusted models directory.")
|
||||||
|
|
||||||
|
# File must end with .py
|
||||||
|
if not abs_model_file.endswith(".py"):
|
||||||
|
raise ValueError("Model file must be a Python (.py) file.")
|
||||||
|
|
||||||
|
# File must exist
|
||||||
|
if not os.path.isfile(abs_model_file):
|
||||||
|
raise ValueError("Model file does not exist.")
|
||||||
|
|
||||||
|
# Validate class name: must be identifier
|
||||||
|
if not model_name or not model_name.isidentifier():
|
||||||
|
raise ValueError("Model class name must be a valid Python identifier.")
|
||||||
|
|
||||||
|
# Load module as before, from absolute, trusted, validated path
|
||||||
|
spec = importlib.util.spec_from_file_location("graph_model", abs_model_file)
|
||||||
module = importlib.util.module_from_spec(spec)
|
module = importlib.util.module_from_spec(spec)
|
||||||
spec.loader.exec_module(module)
|
spec.loader.exec_module(module)
|
||||||
|
|
||||||
|
# Ensure the model_name exists and is a class in the module
|
||||||
|
if not hasattr(module, model_name):
|
||||||
|
raise ValueError(f"Model class '{model_name}' not found in file.")
|
||||||
model_class = getattr(module, model_name)
|
model_class = getattr(module, model_name)
|
||||||
|
if not isinstance(model_class, type):
|
||||||
|
raise ValueError(f"Attribute '{model_name}' is not a class.")
|
||||||
|
|
||||||
return model_class
|
return model_class
|
||||||
|
|
||||||
|
|
@ -218,4 +252,4 @@ if __name__ == "__main__":
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error initializing Cognee MCP server: {str(e)}")
|
logger.error(f"Error initializing Cognee MCP server: {str(e)}")
|
||||||
raise
|
raise
|
||||||
Loading…
Add table
Reference in a new issue