diff --git a/cognee/tasks/repo_processor/local_script_dependencies.py b/cognee/tasks/repo_processor/local_script_dependencies.py index 244431212..4beb584f4 100644 --- a/cognee/tasks/repo_processor/local_script_dependencies.py +++ b/cognee/tasks/repo_processor/local_script_dependencies.py @@ -1,3 +1,4 @@ +from contextlib import contextmanager from typing import List, Dict, Optional import jedi import parso @@ -6,7 +7,17 @@ from pathlib import Path from parso.tree import BaseNode -def get_code_entities(node: parso.tree.NodeOrLeaf) -> List[Dict[str, any]]: +@contextmanager +def add_sys_path(path): + original_sys_path = sys.path.copy() + sys.path.insert(0, path) + try: + yield + finally: + sys.path = original_sys_path + + +def _get_code_entities(node: parso.tree.NodeOrLeaf) -> List[Dict[str, any]]: """ Recursively extract code entities using parso. """ @@ -27,12 +38,12 @@ def get_code_entities(node: parso.tree.NodeOrLeaf) -> List[Dict[str, any]]: # Recursively process child nodes for child in node.children: - code_entity_list.extend(get_code_entities(child)) + code_entity_list.extend(_get_code_entities(child)) return code_entity_list -def update_code_entity(script: jedi.Script, code_entity: Dict[str, any]) -> None: +def _update_code_entity(script: jedi.Script, code_entity: Dict[str, any]) -> None: """ Update a code_entity with (full_name, module_name, module_path) using Jedi """ @@ -42,35 +53,38 @@ def update_code_entity(script: jedi.Script, code_entity: Dict[str, any]) -> None code_entity["module_name"] = getattr(results[0], "module_name", None) code_entity["module_path"] = getattr(results[0], "module_path", None) - -def get_local_script_dependencies(script_path: str, repo_path: Optional[str] = None) -> List[str]: - """ - Extract and return a list of unique module paths that the script depends on. - """ - if repo_path: - sys.path.insert(0, str(Path(repo_path).resolve())) - +def _extract_dependencies(script_path: str) -> List[str]: with open(script_path, "r") as file: source_code = file.read() script = jedi.Script(code=source_code, path=script_path) tree = parso.parse(source_code) - code_entities = get_code_entities(tree) + code_entities = _get_code_entities(tree) for code_entity in code_entities: - update_code_entity(script, code_entity) + _update_code_entity(script, code_entity) module_paths = { entity.get("module_path") for entity in code_entities if entity.get("module_path") } - if repo_path: - repo_path_resolved = str(Path(repo_path).resolve(strict=False)) - module_paths = {path for path in module_paths if str(path).startswith(repo_path_resolved)} - return sorted(path for path in module_paths if path) + return sorted(str(path) for path in module_paths) + +def get_local_script_dependencies(script_path: str, repo_path: Optional[str] = None) -> List[str]: + """ + Extract and return a list of unique module paths that the script depends on. + """ + if repo_path: + repo_path_resolved = str(Path(repo_path).resolve()) + with add_sys_path(repo_path_resolved): + dependencies = _extract_dependencies(script_path) + dependencies = [path for path in dependencies if path.startswith(repo_path_resolved)] + else: + dependencies = _extract_dependencies(script_path) + return dependencies if __name__ == "__main__": # Simple execution example, use absolute paths