1818import numpy as np
1919import onnxruntime as ort
2020import OpenEXR
21- import requests
2221import trimesh
2322import vtk
2423from PIL import Image
2524from vtk .util .numpy_support import numpy_to_vtk
2625
2726from openlifu .util .annotations import OpenLIFUFieldData
27+ from openlifu .util .assets import download_and_install_modnet , get_modnet_path
2828
2929logger_meshrecon = logging .getLogger ("MeshRecon" )
3030logger_meshroom = logging .getLogger ("Meshroom" )
@@ -300,6 +300,7 @@ def run_reconstruction(
300300 locations : List [Tuple [float , float , float ]] | None = None ,
301301 return_durations : bool = False ,
302302 progress_callback : Callable [[int ,str ],None ] | None = None ,
303+ download_masking_model : bool = True ,
303304) -> Tuple [Photoscan , Path ] | Tuple [Photoscan , Path , Dict [str , float ]]:
304305 """Run Meshroom with the given images and pipeline.
305306 Args:
@@ -322,6 +323,8 @@ def run_reconstruction(
322323 return_durations (bool): If True, also return a dictionary mapping node names to durations in seconds.
323324 progress_callback: An optional function that will be called to report progress. The function should accept two arguments:
324325 an integer progress value from 0 to 100 followed by a string message describing the step currently being worked on.
326+ download_masking_model: Whether to auto-download the masking model weights if they are not present;
327+ only relevant if use_masks is enabled.
325328
326329 Returns:
327330 Union[Tuple[Photoscan, Path], Tuple[Photoscan, Path, Dict[str, float]]]:
@@ -432,7 +435,7 @@ def progress_callback(progress_percent : int, step_description : str): # noqa: A
432435 start_time = time .perf_counter ()
433436 masks_dir = temp_dir / "masks"
434437 masks_dir .mkdir (parents = True , exist_ok = True )
435- make_masks (new_paths , masks_dir )
438+ make_masks (new_paths , masks_dir , download_model = download_masking_model )
436439 command .append ( f"PrepareDenseScene_1.masksFolders=['{ masks_dir .as_posix ()} ']" )
437440 durations ["MaskCreation" ] = time .perf_counter () - start_time
438441
@@ -597,38 +600,6 @@ def inverse_transform(img):
597600
598601 return inverse_transform (image ) if inverse else transform (image )
599602
600-
601- def get_modnet_path () -> Path :
602- """Get the MODNet checkpoint path. Download it if not present.
603- """
604- package = "openlifu.nav.modnet_checkpoints"
605- filename = "modnet_photographic_portrait_matting.onnx"
606- url = "https://data.kitware.com/api/v1/file/67feb2cb31a330568827ab32/download"
607- try :
608- # Try to find the checkpoint in the package
609- resource_path = importlib .resources .files (package ) / filename
610- if resource_path .is_file ():
611- logger_meshrecon .info (f"Found existing MODNet checkpoint at { resource_path } " )
612- return resource_path
613- except (FileNotFoundError , ModuleNotFoundError ):
614- pass
615-
616- # Fallback: Download the checkpoint
617- base_dir = Path (importlib .resources .files (package ))
618- full_path = base_dir / filename
619- logger_meshrecon .info (f"MODNet checkpoint not found. Downloading from { url } ..." )
620- response = requests .get (url , stream = True , timeout = (10 , 300 ))
621- if response .status_code == 200 :
622- with open (full_path , 'wb' ) as f :
623- for chunk in response .iter_content (chunk_size = 8192 ):
624- if chunk :
625- f .write (chunk )
626- logger_meshrecon .info (f"Downloaded MODNet checkpoint to { full_path } " )
627- else :
628- raise RuntimeError (f"Failed to download MODNet checkpoint: { response .status_code } - { response .text } " )
629-
630- return full_path
631-
632603def preprocess_image_modnet (image : np .ndarray , ref_size : int = 512 ) -> np .ndarray :
633604 """
634605 Preprocess an input image for MODNet inference.
@@ -675,7 +646,7 @@ def preprocess_image_modnet(image: np.ndarray, ref_size: int = 512) -> np.ndarra
675646 return image
676647
677648
678- def make_masks (image_paths : list [Path ], output_dir : Path , threshold : float = 0.01 ) -> None :
649+ def make_masks (image_paths : list [Path ], output_dir : Path , threshold : float = 0.01 , download_model = True ) -> None :
679650 """
680651 Runs MODNet on a list of image paths and saves the output masks.
681652
@@ -687,10 +658,23 @@ def make_masks(image_paths: list[Path], output_dir: Path, threshold: float = 0.0
687658 image_paths (list[Path]): List of input image file paths.
688659 output_dir (Path): Directory where the output masks will be saved.
689660 threshold (float): Threshold to binarize the soft segmentation output.
661+ download_model (bool): Whether to auto-download the model weights if they are not present.
690662 """
663+
691664 # Load the ONNX model
665+
692666 ckpt_path = get_modnet_path ()
667+ if not ckpt_path .exists ():
668+ if download_model :
669+ logger_meshrecon .info (f"Downloading MODNet checkpoint to { ckpt_path } " )
670+ download_and_install_modnet ()
671+ else :
672+ raise FileNotFoundError (f"MODNet checkpoint not found at { ckpt_path } . Install it using an appropriate utility in openlifu.util.assets." )
673+ else :
674+ logger_meshrecon .info (f"Found existing MODNet checkpoint at { ckpt_path } " )
675+
693676 session = ort .InferenceSession (ckpt_path , providers = ["CPUExecutionProvider" ]) # or CUDAExecutionProvider
677+
694678 for image_path in image_paths :
695679 image = Image .open (image_path )
696680 exif = image .getexif ()
0 commit comments