Skip to content

Commit ce1060f

Browse files
Add MODNet checkpoint asset utilities (#377)
This allows you to now call run_reconstruction with download_masking_model set to False. This will skip the download of the MODNet weights, instead requiring the user to somehow ensure the model weights are installed. They can be installed from file using the newly added openlifu.util.assets.install_modnet_from_file.
1 parent 5d8795c commit ce1060f

2 files changed

Lines changed: 45 additions & 35 deletions

File tree

src/openlifu/nav/photoscan.py

Lines changed: 19 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,13 @@
1818
import numpy as np
1919
import onnxruntime as ort
2020
import OpenEXR
21-
import requests
2221
import trimesh
2322
import vtk
2423
from PIL import Image
2524
from vtk.util.numpy_support import numpy_to_vtk
2625

2726
from openlifu.util.annotations import OpenLIFUFieldData
27+
from openlifu.util.assets import download_and_install_modnet, get_modnet_path
2828

2929
logger_meshrecon = logging.getLogger("MeshRecon")
3030
logger_meshroom = logging.getLogger("Meshroom")
@@ -300,6 +300,7 @@ def run_reconstruction(
300300
locations: List[Tuple[float, float, float]] | None = None,
301301
return_durations: bool = False,
302302
progress_callback : Callable[[int,str],None] | None = None,
303+
download_masking_model: bool = True,
303304
) -> Tuple[Photoscan, Path] | Tuple[Photoscan, Path, Dict[str, float]]:
304305
"""Run Meshroom with the given images and pipeline.
305306
Args:
@@ -322,6 +323,8 @@ def run_reconstruction(
322323
return_durations (bool): If True, also return a dictionary mapping node names to durations in seconds.
323324
progress_callback: An optional function that will be called to report progress. The function should accept two arguments:
324325
an integer progress value from 0 to 100 followed by a string message describing the step currently being worked on.
326+
download_masking_model: Whether to auto-download the masking model weights if they are not present;
327+
only relevant if use_masks is enabled.
325328
326329
Returns:
327330
Union[Tuple[Photoscan, Path], Tuple[Photoscan, Path, Dict[str, float]]]:
@@ -432,7 +435,7 @@ def progress_callback(progress_percent : int, step_description : str): # noqa: A
432435
start_time = time.perf_counter()
433436
masks_dir = temp_dir / "masks"
434437
masks_dir.mkdir(parents=True, exist_ok=True)
435-
make_masks(new_paths, masks_dir)
438+
make_masks(new_paths, masks_dir, download_model=download_masking_model)
436439
command.append( f"PrepareDenseScene_1.masksFolders=['{masks_dir.as_posix()}']" )
437440
durations["MaskCreation"] = time.perf_counter() - start_time
438441

@@ -597,38 +600,6 @@ def inverse_transform(img):
597600

598601
return inverse_transform(image) if inverse else transform(image)
599602

600-
601-
def get_modnet_path() -> Path:
602-
"""Get the MODNet checkpoint path. Download it if not present.
603-
"""
604-
package = "openlifu.nav.modnet_checkpoints"
605-
filename = "modnet_photographic_portrait_matting.onnx"
606-
url = "https://data.kitware.com/api/v1/file/67feb2cb31a330568827ab32/download"
607-
try:
608-
# Try to find the checkpoint in the package
609-
resource_path = importlib.resources.files(package) / filename
610-
if resource_path.is_file():
611-
logger_meshrecon.info(f"Found existing MODNet checkpoint at {resource_path}")
612-
return resource_path
613-
except (FileNotFoundError, ModuleNotFoundError):
614-
pass
615-
616-
# Fallback: Download the checkpoint
617-
base_dir = Path(importlib.resources.files(package))
618-
full_path = base_dir / filename
619-
logger_meshrecon.info(f"MODNet checkpoint not found. Downloading from {url}...")
620-
response = requests.get(url, stream=True, timeout=(10, 300))
621-
if response.status_code == 200:
622-
with open(full_path, 'wb') as f:
623-
for chunk in response.iter_content(chunk_size=8192):
624-
if chunk:
625-
f.write(chunk)
626-
logger_meshrecon.info(f"Downloaded MODNet checkpoint to {full_path}")
627-
else:
628-
raise RuntimeError(f"Failed to download MODNet checkpoint: {response.status_code} - {response.text}")
629-
630-
return full_path
631-
632603
def preprocess_image_modnet(image: np.ndarray, ref_size: int = 512) -> np.ndarray:
633604
"""
634605
Preprocess an input image for MODNet inference.
@@ -675,7 +646,7 @@ def preprocess_image_modnet(image: np.ndarray, ref_size: int = 512) -> np.ndarra
675646
return image
676647

677648

678-
def make_masks(image_paths: list[Path], output_dir: Path, threshold: float = 0.01) -> None:
649+
def make_masks(image_paths: list[Path], output_dir: Path, threshold: float = 0.01, download_model=True) -> None:
679650
"""
680651
Runs MODNet on a list of image paths and saves the output masks.
681652
@@ -687,10 +658,23 @@ def make_masks(image_paths: list[Path], output_dir: Path, threshold: float = 0.0
687658
image_paths (List[str]): List of input image file paths.
688659
output_dir (str): Directory where the output masks will be saved.
689660
threshold (float): Threshold to binarize the soft segmentation output.
661+
download_model (bool): Whether to auto-download the model weights if they are not present.
690662
"""
663+
691664
# Load the ONNX model
665+
692666
ckpt_path = get_modnet_path()
667+
if not ckpt_path.exists():
668+
if download_model:
669+
logger_meshrecon.info(f"Downloading MODNet checkpoint to {ckpt_path}")
670+
download_and_install_modnet()
671+
else:
672+
raise FileNotFoundError(f"MODNet checkpoint not found at {ckpt_path}. Install it using an appropirate utility in openlifu.util.assets.")
673+
else:
674+
logger_meshrecon.info(f"Found existing MODNet checkpoint at {ckpt_path}")
675+
693676
session = ort.InferenceSession(ckpt_path, providers=["CPUExecutionProvider"]) # or CUDAExecutionProvider
677+
694678
for image_path in image_paths:
695679
image = Image.open(image_path)
696680
exif = image.getexif()

src/openlifu/util/assets.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import importlib
56
import shutil
67
import tempfile
78
from pathlib import Path
@@ -50,3 +51,28 @@ def install_asset(destination:PathLike, path_to_asset:PathLike|None, url_to_asse
5051
Path(temp_file_path).unlink()
5152
else:
5253
raise ValueError("Either path_to_asset or url_to_asset must be provided.")
54+
55+
56+
def get_modnet_path() -> Path:
57+
"""Get the MODNet checkpoint path.
58+
It may or may not exist; see `download_and_install_modnet` and `install_modnet_from_file`.
59+
If `get_modnet_path().exists()` is False, then use one of those two options to install.
60+
"""
61+
package = "openlifu.nav.modnet_checkpoints"
62+
filename = "modnet_photographic_portrait_matting.onnx"
63+
base_dir = Path(importlib.resources.files(package))
64+
return base_dir / filename
65+
66+
def download_and_install_modnet() -> Path:
67+
"""Download and install the MODNet checkpoint. Returns path to installed MODNet checkpoint."""
68+
url = "https://data.kitware.com/api/v1/file/67feb2cb31a330568827ab32/download"
69+
modnet_path = get_modnet_path()
70+
install_asset(modnet_path, url_to_asset=url)
71+
return modnet_path
72+
73+
def install_modnet_from_file(path_to_modnet_file:PathLike) -> Path:
74+
"""Copy MODNet checkpoint to the appropriate place for openlifu to use it.
75+
Returns path to installed MODNet checkpoint."""
76+
modnet_path = get_modnet_path()
77+
install_asset(modnet_path, path_to_asset=path_to_modnet_file)
78+
return modnet_path

0 commit comments

Comments
 (0)