refactor: Modify sys.path in context manager
This commit is contained in:
parent
ba83d71269
commit
e148d32c14
1 changed files with 31 additions and 17 deletions
|
|
@ -1,3 +1,4 @@
|
|||
from contextlib import contextmanager
|
||||
from typing import List, Dict, Optional
|
||||
import jedi
|
||||
import parso
|
||||
|
|
@ -6,7 +7,17 @@ from pathlib import Path
|
|||
from parso.tree import BaseNode
|
||||
|
||||
|
||||
def get_code_entities(node: parso.tree.NodeOrLeaf) -> List[Dict[str, any]]:
|
||||
@contextmanager
|
||||
def add_sys_path(path):
|
||||
original_sys_path = sys.path.copy()
|
||||
sys.path.insert(0, path)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
sys.path = original_sys_path
|
||||
|
||||
|
||||
def _get_code_entities(node: parso.tree.NodeOrLeaf) -> List[Dict[str, any]]:
|
||||
"""
|
||||
Recursively extract code entities using parso.
|
||||
"""
|
||||
|
|
@ -27,12 +38,12 @@ def get_code_entities(node: parso.tree.NodeOrLeaf) -> List[Dict[str, any]]:
|
|||
|
||||
# Recursively process child nodes
|
||||
for child in node.children:
|
||||
code_entity_list.extend(get_code_entities(child))
|
||||
code_entity_list.extend(_get_code_entities(child))
|
||||
|
||||
return code_entity_list
|
||||
|
||||
|
||||
def update_code_entity(script: jedi.Script, code_entity: Dict[str, any]) -> None:
|
||||
def _update_code_entity(script: jedi.Script, code_entity: Dict[str, any]) -> None:
|
||||
"""
|
||||
Update a code_entity with (full_name, module_name, module_path) using Jedi
|
||||
"""
|
||||
|
|
@ -42,35 +53,38 @@ def update_code_entity(script: jedi.Script, code_entity: Dict[str, any]) -> None
|
|||
code_entity["module_name"] = getattr(results[0], "module_name", None)
|
||||
code_entity["module_path"] = getattr(results[0], "module_path", None)
|
||||
|
||||
|
||||
def get_local_script_dependencies(script_path: str, repo_path: Optional[str] = None) -> List[str]:
|
||||
"""
|
||||
Extract and return a list of unique module paths that the script depends on.
|
||||
"""
|
||||
if repo_path:
|
||||
sys.path.insert(0, str(Path(repo_path).resolve()))
|
||||
|
||||
def _extract_dependencies(script_path: str) -> List[str]:
|
||||
with open(script_path, "r") as file:
|
||||
source_code = file.read()
|
||||
|
||||
script = jedi.Script(code=source_code, path=script_path)
|
||||
|
||||
tree = parso.parse(source_code)
|
||||
code_entities = get_code_entities(tree)
|
||||
code_entities = _get_code_entities(tree)
|
||||
|
||||
for code_entity in code_entities:
|
||||
update_code_entity(script, code_entity)
|
||||
_update_code_entity(script, code_entity)
|
||||
|
||||
module_paths = {
|
||||
entity.get("module_path")
|
||||
for entity in code_entities
|
||||
if entity.get("module_path")
|
||||
}
|
||||
if repo_path:
|
||||
repo_path_resolved = str(Path(repo_path).resolve(strict=False))
|
||||
module_paths = {path for path in module_paths if str(path).startswith(repo_path_resolved)}
|
||||
|
||||
return sorted(path for path in module_paths if path)
|
||||
return sorted(str(path) for path in module_paths)
|
||||
|
||||
def get_local_script_dependencies(script_path: str, repo_path: Optional[str] = None) -> List[str]:
|
||||
"""
|
||||
Extract and return a list of unique module paths that the script depends on.
|
||||
"""
|
||||
if repo_path:
|
||||
repo_path_resolved = str(Path(repo_path).resolve())
|
||||
with add_sys_path(repo_path_resolved):
|
||||
dependencies = _extract_dependencies(script_path)
|
||||
dependencies = [path for path in dependencies if path.startswith(repo_path_resolved)]
|
||||
else:
|
||||
dependencies = _extract_dependencies(script_path)
|
||||
return dependencies
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Simple execution example, use absolute paths
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue