-
Notifications
You must be signed in to change notification settings - Fork 8
Changed workflow for setting the masking value in the skull stripping #171
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
c5859c4
2395565
776cf77
9a523df
f8c99d0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,6 +4,7 @@ | |
| from pathlib import Path | ||
| from typing import Optional, Union | ||
| from enum import Enum | ||
| import numpy as np | ||
|
|
||
| from auxiliary.io import read_image, write_image | ||
| from brainles_hd_bet import run_hd_bet | ||
|
|
@@ -14,7 +15,23 @@ class Mode(Enum): | |
| ACCURATE = "accurate" | ||
|
|
||
|
|
||
| class BrainExtractor: | ||
| class BrainExtractor(ABC): | ||
| def __init__( | ||
| self, | ||
| masking_value: Optional[Union[int, float]] = None, | ||
| ): | ||
| """ | ||
| Base class for skull stripping medical images using brain masks. | ||
|
|
||
| Subclasses should implement the `extract` method to generate a skull stripped image | ||
| based on the provided input image and mask. | ||
| """ | ||
| # Just as in the defacer, masking value is a global value defined across all images and modalities | ||
| # If no value is passed, the minimum of a given input image is chosen | ||
| # TODO: Consider extending this to modality-specific masking values in the future, this should | ||
| # probably be implemented as a property of the the specific modality | ||
| self.masking_value = masking_value | ||
|
|
||
| @abstractmethod | ||
| def extract( | ||
| self, | ||
|
|
@@ -63,8 +80,17 @@ def apply_mask( | |
| if input_data.shape != mask_data.shape: | ||
| raise ValueError("Input image and mask must have the same dimensions.") | ||
|
|
||
| # Mask and save it | ||
| masked_data = input_data * mask_data | ||
| # check whether a global masking value was passed, otherwise choose minimum | ||
| if self.masking_value is None: | ||
| current_masking_value = np.min(input_data) | ||
| else: | ||
| current_masking_value = ( | ||
| np.array(self.masking_value).astype(input_data.dtype).item() | ||
| ) | ||
| # Apply mask (element-wise either input or masking value) | ||
| masked_data = np.where( | ||
| mask_data.astype(bool), input_data, current_masking_value | ||
| ) | ||
|
Comment on lines
+83
to
+93
|
||
|
|
||
| try: | ||
| write_image( | ||
|
|
@@ -78,6 +104,15 @@ def apply_mask( | |
|
|
||
|
|
||
| class HDBetExtractor(BrainExtractor): | ||
| def __init__(self, masking_value: Optional[Union[int, float]] = None): | ||
| """ | ||
| Brain extraction HDBet implementation. | ||
|
|
||
| Args: | ||
| masking_value (Optional[Union[int, float]], optional): global value to be inserted in the masked areas. Default is None which leads to the minimum of each respective image. | ||
| """ | ||
| super().__init__(masking_value=masking_value) | ||
|
|
||
| def extract( | ||
| self, | ||
| input_image_path: Union[str, Path], | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,7 +21,9 @@ | |
|
|
||
| class SynthStripExtractor(BrainExtractor): | ||
|
|
||
| def __init__(self, border: int = 1): | ||
| def __init__( | ||
| self, border: int = 1, masking_value: Optional[Union[int, float]] = None | ||
| ): | ||
|
Comment on lines
+24
to
+26
|
||
| """ | ||
| Brain extraction using SynthStrip with preprocessing conforming to model requirements. | ||
|
|
||
|
|
@@ -31,9 +33,10 @@ def __init__(self, border: int = 1): | |
|
|
||
| Args: | ||
| border (int): Mask border threshold in mm. Defaults to 1. | ||
| """ | ||
| masking_value (Optional[Union[int, float]], optional): global value to be inserted in the masked areas. Default is None which leads to the minimum of each respective image. | ||
|
|
||
| super().__init__() | ||
| """ | ||
| super().__init__(masking_value=masking_value) | ||
| self.border = border | ||
|
|
||
| def _setup_model(self, device: torch.device) -> StripModel: | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.