refactor: Modify sys.path in context manager

This commit is contained in:
lxobr 2024-11-15 17:59:10 +01:00
parent ba83d71269
commit e148d32c14

View file

@ -1,3 +1,4 @@
from contextlib import contextmanager
from typing import List, Dict, Optional
import jedi
import parso
@ -6,7 +7,17 @@ from pathlib import Path
from parso.tree import BaseNode
def get_code_entities(node: parso.tree.NodeOrLeaf) -> List[Dict[str, any]]:
@contextmanager
def add_sys_path(path):
original_sys_path = sys.path.copy()
sys.path.insert(0, path)
try:
yield
finally:
sys.path = original_sys_path
def _get_code_entities(node: parso.tree.NodeOrLeaf) -> List[Dict[str, any]]:
"""
Recursively extract code entities using parso.
"""
@ -27,12 +38,12 @@ def get_code_entities(node: parso.tree.NodeOrLeaf) -> List[Dict[str, any]]:
# Recursively process child nodes
for child in node.children:
code_entity_list.extend(get_code_entities(child))
code_entity_list.extend(_get_code_entities(child))
return code_entity_list
def update_code_entity(script: jedi.Script, code_entity: Dict[str, any]) -> None:
def _update_code_entity(script: jedi.Script, code_entity: Dict[str, any]) -> None:
"""
Update a code_entity with (full_name, module_name, module_path) using Jedi
"""
@ -42,35 +53,38 @@ def update_code_entity(script: jedi.Script, code_entity: Dict[str, any]) -> None
code_entity["module_name"] = getattr(results[0], "module_name", None)
code_entity["module_path"] = getattr(results[0], "module_path", None)
def get_local_script_dependencies(script_path: str, repo_path: Optional[str] = None) -> List[str]:
"""
Extract and return a list of unique module paths that the script depends on.
"""
if repo_path:
sys.path.insert(0, str(Path(repo_path).resolve()))
def _extract_dependencies(script_path: str) -> List[str]:
with open(script_path, "r") as file:
source_code = file.read()
script = jedi.Script(code=source_code, path=script_path)
tree = parso.parse(source_code)
code_entities = get_code_entities(tree)
code_entities = _get_code_entities(tree)
for code_entity in code_entities:
update_code_entity(script, code_entity)
_update_code_entity(script, code_entity)
module_paths = {
entity.get("module_path")
for entity in code_entities
if entity.get("module_path")
}
if repo_path:
repo_path_resolved = str(Path(repo_path).resolve(strict=False))
module_paths = {path for path in module_paths if str(path).startswith(repo_path_resolved)}
return sorted(path for path in module_paths if path)
return sorted(str(path) for path in module_paths)
def get_local_script_dependencies(script_path: str, repo_path: Optional[str] = None) -> List[str]:
"""
Extract and return a list of unique module paths that the script depends on.
"""
if repo_path:
repo_path_resolved = str(Path(repo_path).resolve())
with add_sys_path(repo_path_resolved):
dependencies = _extract_dependencies(script_path)
dependencies = [path for path in dependencies if path.startswith(repo_path_resolved)]
else:
dependencies = _extract_dependencies(script_path)
return dependencies
if __name__ == "__main__":
# Simple execution example, use absolute paths