refactor: Modify sys.path in context manager

This commit is contained in:
lxobr 2024-11-15 17:59:10 +01:00
parent ba83d71269
commit e148d32c14

View file

@ -1,3 +1,4 @@
from contextlib import contextmanager
from typing import List, Dict, Optional from typing import List, Dict, Optional
import jedi import jedi
import parso import parso
@ -6,7 +7,17 @@ from pathlib import Path
from parso.tree import BaseNode from parso.tree import BaseNode
def get_code_entities(node: parso.tree.NodeOrLeaf) -> List[Dict[str, any]]: @contextmanager
def add_sys_path(path):
original_sys_path = sys.path.copy()
sys.path.insert(0, path)
try:
yield
finally:
sys.path = original_sys_path
def _get_code_entities(node: parso.tree.NodeOrLeaf) -> List[Dict[str, any]]:
""" """
Recursively extract code entities using parso. Recursively extract code entities using parso.
""" """
@ -27,12 +38,12 @@ def get_code_entities(node: parso.tree.NodeOrLeaf) -> List[Dict[str, any]]:
# Recursively process child nodes # Recursively process child nodes
for child in node.children: for child in node.children:
code_entity_list.extend(get_code_entities(child)) code_entity_list.extend(_get_code_entities(child))
return code_entity_list return code_entity_list
def update_code_entity(script: jedi.Script, code_entity: Dict[str, any]) -> None: def _update_code_entity(script: jedi.Script, code_entity: Dict[str, any]) -> None:
""" """
Update a code_entity with (full_name, module_name, module_path) using Jedi Update a code_entity with (full_name, module_name, module_path) using Jedi
""" """
@ -42,35 +53,38 @@ def update_code_entity(script: jedi.Script, code_entity: Dict[str, any]) -> None
code_entity["module_name"] = getattr(results[0], "module_name", None) code_entity["module_name"] = getattr(results[0], "module_name", None)
code_entity["module_path"] = getattr(results[0], "module_path", None) code_entity["module_path"] = getattr(results[0], "module_path", None)
def _extract_dependencies(script_path: str) -> List[str]:
def get_local_script_dependencies(script_path: str, repo_path: Optional[str] = None) -> List[str]:
"""
Extract and return a list of unique module paths that the script depends on.
"""
if repo_path:
sys.path.insert(0, str(Path(repo_path).resolve()))
with open(script_path, "r") as file: with open(script_path, "r") as file:
source_code = file.read() source_code = file.read()
script = jedi.Script(code=source_code, path=script_path) script = jedi.Script(code=source_code, path=script_path)
tree = parso.parse(source_code) tree = parso.parse(source_code)
code_entities = get_code_entities(tree) code_entities = _get_code_entities(tree)
for code_entity in code_entities: for code_entity in code_entities:
update_code_entity(script, code_entity) _update_code_entity(script, code_entity)
module_paths = { module_paths = {
entity.get("module_path") entity.get("module_path")
for entity in code_entities for entity in code_entities
if entity.get("module_path") if entity.get("module_path")
} }
if repo_path:
repo_path_resolved = str(Path(repo_path).resolve(strict=False))
module_paths = {path for path in module_paths if str(path).startswith(repo_path_resolved)}
return sorted(path for path in module_paths if path) return sorted(str(path) for path in module_paths)
def get_local_script_dependencies(script_path: str, repo_path: Optional[str] = None) -> List[str]:
"""
Extract and return a list of unique module paths that the script depends on.
"""
if repo_path:
repo_path_resolved = str(Path(repo_path).resolve())
with add_sys_path(repo_path_resolved):
dependencies = _extract_dependencies(script_path)
dependencies = [path for path in dependencies if path.startswith(repo_path_resolved)]
else:
dependencies = _extract_dependencies(script_path)
return dependencies
if __name__ == "__main__": if __name__ == "__main__":
# Simple execution example, use absolute paths # Simple execution example, use absolute paths