Skip to content

Commit 67750ec

Browse files
Add kwave inert import utility (#377)
This is a way of importing kwave while skipping its automatic download of binaries. We need to do this in order to get the binary paths from kwave to support offline use, or delayed download after a prompt, etc. It's not pretty, but even if we upgrade to the latest k-wave it should still work. I will add a unit test that uses this to make sure we catch any issues while upgrading kwave.
1 parent ce1060f commit 67750ec

1 file changed

Lines changed: 58 additions & 0 deletions

File tree

src/openlifu/util/assets.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,13 @@
22

33
from __future__ import annotations
44

5+
import ast
56
import importlib
67
import shutil
8+
import sys
79
import tempfile
810
from pathlib import Path
11+
from types import ModuleType
912

1013
import requests
1114

@@ -76,3 +79,58 @@ def install_modnet_from_file(path_to_modnet_file:PathLike) -> Path:
7679
modnet_path = get_modnet_path()
7780
install_asset(modnet_path, path_to_asset=path_to_modnet_file)
7881
return modnet_path
82+
83+
def _import_without_calls(pkg: str, banned_calls:list[str], register=False) -> ModuleType:
84+
"""Import `pkg` but strip any top-level statements that call a banned function.
85+
86+
It is simplistic: it is looking at the syntax tree and stripping out any node that
87+
has a banned function call in any of its descendent nodes. There are lots of ways to break
88+
this if there is enough misdirection in a banned function call. The point of this is just
89+
to help handle a specific issue we have with kwave's binary download.
90+
91+
Args:
92+
pkg: The name of the package to import
93+
banned_calls: A list of functions to import
94+
register: Whether to add the module in global import registry.
95+
Doing so makes any future imports of the module via the usual `import`
96+
statement end up referring to the version imported here.
97+
98+
Returns the module.
99+
"""
100+
spec = importlib.util.find_spec(pkg)
101+
if not spec or not spec.submodule_search_locations:
102+
raise ImportError(f"Can't find package {pkg!r}")
103+
104+
init_path = Path(spec.submodule_search_locations[0]) / "__init__.py"
105+
src = init_path.read_text(encoding="utf-8")
106+
tree = ast.parse(src, filename=str(init_path))
107+
108+
# this function tells whether a top level statement tries to call a banned function anywhere inside it
109+
def stmt_calls_banned(stmt: ast.stmt) -> bool:
110+
for node in ast.walk(stmt):
111+
if isinstance(node, ast.Call):
112+
f = node.func
113+
if isinstance(f, ast.Name) and f.id in banned_calls:
114+
return True
115+
if isinstance(f, ast.Attribute) and f.attr in banned_calls:
116+
return True
117+
return False
118+
119+
tree.body = [s for s in tree.body if not stmt_calls_banned(s)] # strip out offending top level statements
120+
code = compile(tree, str(init_path), "exec")
121+
122+
module = ModuleType(pkg) # create a blank module object
123+
module.__file__ = str(init_path)
124+
module.__package__ = pkg
125+
g = module.__dict__ # build up the context in which we will execute the module code
126+
g["__name__"] = pkg
127+
g["__file__"] = str(init_path)
128+
exec(code, g, g)
129+
130+
if register:
131+
sys.modules[pkg] = module
132+
return module
133+
134+
def _import_kwave_inertly() -> ModuleType:
135+
"""Import kwave without allowing it to install binaries"""
136+
return _import_without_calls("kwave", banned_calls=["install_binaries"])

0 commit comments

Comments
 (0)