From 8f95fb140cfcf48b95ab5b53e8e17485754779e3 Mon Sep 17 00:00:00 2001 From: ugbotueferhire Date: Mon, 25 May 2026 01:05:12 +0100 Subject: [PATCH 1/3] Fix diffusion inferers to support diffusers-style schedulers --- monai/inferers/inferer.py | 69 +++++++++++++++---- tests/inferers/test_diffusion_inferer.py | 42 +++++++++++ .../inferers/test_latent_diffusion_inferer.py | 58 ++++++++++++++++ 3 files changed, 157 insertions(+), 12 deletions(-) diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index ee94b1ebdb..3b2d6ec447 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -11,6 +11,7 @@ from __future__ import annotations +import inspect import math import warnings from abc import ABC, abstractmethod @@ -861,6 +862,42 @@ def __init__(self, scheduler: Scheduler) -> None: # type: ignore[override] self.scheduler = scheduler + @staticmethod + def _scheduler_step_supports_kwarg(scheduler: Scheduler, kwarg: str) -> bool: + try: + return kwarg in inspect.signature(scheduler.step).parameters + except (TypeError, ValueError): + return False + + @staticmethod + def _get_previous_sample_from_step_output(step_output: Any) -> torch.Tensor: + if isinstance(step_output, tuple): + return step_output[0] + if isinstance(step_output, Mapping): + return step_output["prev_sample"] + if hasattr(step_output, "prev_sample"): + return step_output.prev_sample + raise TypeError("Unsupported scheduler.step output. Expected a tuple or an object with `prev_sample`.") + + def _scheduler_step( + self, + scheduler: Scheduler, + model_output: torch.Tensor, + timestep: int | torch.Tensor, + sample: torch.Tensor, + next_timestep: int | torch.Tensor | None = None, + ) -> torch.Tensor: + step_kwargs = {} + if self._scheduler_step_supports_kwarg(scheduler, "return_dict"): + step_kwargs["return_dict"] = False + + if isinstance(scheduler, RFlowScheduler): + step_output = scheduler.step(model_output, timestep, sample, next_timestep, **step_kwargs) # type: ignore + else: + step_output = scheduler.step(model_output, timestep, sample, **step_kwargs) # type: ignore + + return self._get_previous_sample_from_step_output(step_output) + def __call__( # type: ignore[override] self, inputs: torch.Tensor, @@ -940,7 +977,12 @@ def sample( scheduler = self.scheduler image = input_noise - all_next_timesteps = torch.cat((scheduler.timesteps[1:], torch.tensor([0], dtype=scheduler.timesteps.dtype))) + all_next_timesteps = torch.cat( + ( + scheduler.timesteps[1:], + torch.tensor([0], dtype=scheduler.timesteps.dtype, device=scheduler.timesteps.device), + ) + ) if verbose and has_tqdm: progress_bar = tqdm( zip(scheduler.timesteps, all_next_timesteps), @@ -984,10 +1026,9 @@ def sample( model_output = model_output_uncond + cfg * (model_output_cond - model_output_uncond) # 2. compute previous image: x_t -> x_t-1 - if not isinstance(scheduler, RFlowScheduler): - image, _ = scheduler.step(model_output, t, image) # type: ignore - else: - image, _ = scheduler.step(model_output, t, image, next_t) # type: ignore + image = self._scheduler_step( + scheduler=scheduler, model_output=model_output, timestep=t, sample=image, next_timestep=next_t + ) if save_intermediates and t % intermediate_steps == 0: intermediates.append(image) @@ -1046,7 +1087,7 @@ def get_likelihood( total_kl = torch.zeros(inputs.shape[0]).to(inputs.device) for t in progress_bar: timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long() - noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) + noisy_image = scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) diffusion_model = ( partial(diffusion_model, seg=seg) if isinstance(diffusion_model, SPADEDiffusionModelUNet) @@ -1509,7 +1550,12 @@ def sample( # type: ignore[override] scheduler = self.scheduler image = input_noise - all_next_timesteps = torch.cat((scheduler.timesteps[1:], torch.tensor([0], dtype=scheduler.timesteps.dtype))) + all_next_timesteps = torch.cat( + ( + scheduler.timesteps[1:], + torch.tensor([0], dtype=scheduler.timesteps.dtype, device=scheduler.timesteps.device), + ) + ) if verbose and has_tqdm: progress_bar = tqdm( zip(scheduler.timesteps, all_next_timesteps), @@ -1583,10 +1629,9 @@ def sample( # type: ignore[override] model_output = model_output_uncond + cfg * (model_output_cond - model_output_uncond) # 3. compute previous image: x_t -> x_t-1 - if not isinstance(scheduler, RFlowScheduler): - image, _ = scheduler.step(model_output, t, image) # type: ignore - else: - image, _ = scheduler.step(model_output, t, image, next_t) # type: ignore + image = self._scheduler_step( + scheduler=scheduler, model_output=model_output, timestep=t, sample=image, next_timestep=next_t + ) if save_intermediates and t % intermediate_steps == 0: intermediates.append(image) @@ -1647,7 +1692,7 @@ def get_likelihood( # type: ignore[override] total_kl = torch.zeros(inputs.shape[0]).to(inputs.device) for t in progress_bar: timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long() - noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) + noisy_image = scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) diffuse = diffusion_model if isinstance(diffusion_model, SPADEDiffusionModelUNet): diff --git a/tests/inferers/test_diffusion_inferer.py b/tests/inferers/test_diffusion_inferer.py index 81874ed3a8..5099a482e0 100644 --- a/tests/inferers/test_diffusion_inferer.py +++ b/tests/inferers/test_diffusion_inferer.py @@ -55,6 +55,31 @@ ] +class DiffusersLikeSchedulerOutput: + def __init__(self, prev_sample: torch.Tensor, pred_original_sample: torch.Tensor) -> None: + self.prev_sample = prev_sample + self.pred_original_sample = pred_original_sample + + +class DiffusersStyleDDPMScheduler(DDPMScheduler): + def step( + self, + model_output: torch.Tensor, + timestep: int, + sample: torch.Tensor, + generator: torch.Generator | None = None, + return_dict: bool = True, + ): + prev_sample, pred_original_sample = super().step( + model_output=model_output, timestep=timestep, sample=sample, generator=generator + ) + if return_dict: + return DiffusersLikeSchedulerOutput( + prev_sample=prev_sample, pred_original_sample=pred_original_sample + ) + return prev_sample, pred_original_sample + + class TestDiffusionSamplingInferer(unittest.TestCase): @parameterized.expand(TEST_CASES) @skipUnless(has_einops, "Requires einops") @@ -126,6 +151,23 @@ def test_ddpm_sampler(self, model_params, input_shape): ) self.assertEqual(len(intermediates), 10) + @parameterized.expand(TEST_CASES) + @skipUnless(has_einops, "Requires einops") + def test_diffusers_style_ddpm_sampler(self, model_params, input_shape): + model = DiffusionModelUNet(**model_params) + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model.to(device) + model.eval() + noise = torch.randn(input_shape).to(device) + scheduler = DiffusersStyleDDPMScheduler(num_train_timesteps=1000) + inferer = DiffusionInferer(scheduler=scheduler) + scheduler.set_timesteps(num_inference_steps=10) + sample, intermediates = inferer.sample( + input_noise=noise, diffusion_model=model, scheduler=scheduler, save_intermediates=True, intermediate_steps=1 + ) + self.assertEqual(sample.shape, noise.shape) + self.assertEqual(len(intermediates), 10) + @parameterized.expand(TEST_CASES) @skipUnless(has_einops, "Requires einops") def test_ddim_sampler(self, model_params, input_shape): diff --git a/tests/inferers/test_latent_diffusion_inferer.py b/tests/inferers/test_latent_diffusion_inferer.py index ab80363cde..e12e2b963c 100644 --- a/tests/inferers/test_latent_diffusion_inferer.py +++ b/tests/inferers/test_latent_diffusion_inferer.py @@ -313,6 +313,33 @@ ], ] +TEST_CASES_DIFFUSERS = [TEST_CASES[0]] + + +class DiffusersLikeSchedulerOutput: + def __init__(self, prev_sample: torch.Tensor, pred_original_sample: torch.Tensor) -> None: + self.prev_sample = prev_sample + self.pred_original_sample = pred_original_sample + + +class DiffusersStyleDDPMScheduler(DDPMScheduler): + def step( + self, + model_output: torch.Tensor, + timestep: int, + sample: torch.Tensor, + generator: torch.Generator | None = None, + return_dict: bool = True, + ): + prev_sample, pred_original_sample = super().step( + model_output=model_output, timestep=timestep, sample=sample, generator=generator + ) + if return_dict: + return DiffusersLikeSchedulerOutput( + prev_sample=prev_sample, pred_original_sample=pred_original_sample + ) + return prev_sample, pred_original_sample + class TestDiffusionSamplingInferer(unittest.TestCase): @parameterized.expand(TEST_CASES) @@ -414,6 +441,37 @@ def test_sample_shape( ) self.assertEqual(sample.shape, input_shape) + @parameterized.expand(TEST_CASES_DIFFUSERS) + @skipUnless(has_einops, "Requires einops") + def test_diffusers_style_ddpm_sample_shape( + self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape + ): + if ae_model_type == "AutoencoderKL": + stage_1 = AutoencoderKL(**autoencoder_params) + else: + stage_1 = VQVAE(**autoencoder_params) + + if dm_model_type == "SPADEDiffusionModelUNet": + stage_2 = SPADEDiffusionModelUNet(**stage_2_params) + else: + stage_2 = DiffusionModelUNet(**stage_2_params) + + device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1.to(device) + stage_2.to(device) + stage_1.eval() + stage_2.eval() + + noise = torch.randn(latent_shape).to(device) + scheduler = DiffusersStyleDDPMScheduler(num_train_timesteps=1000) + inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) + scheduler.set_timesteps(num_inference_steps=10) + + sample = inferer.sample( + input_noise=noise, autoencoder_model=stage_1, diffusion_model=stage_2, scheduler=scheduler + ) + self.assertEqual(sample.shape, input_shape) + @parameterized.expand(TEST_CASES) @skipUnless(has_einops, "Requires einops") def test_sample_shape_with_cfg( From a17e9185f3dce1810b121603b627892482a2c0f4 Mon Sep 17 00:00:00 2001 From: ugbotueferhire Date: Sun, 31 May 2026 12:53:58 +0100 Subject: [PATCH 2/3] Support diffusers schedulers in inferers Signed-off-by: ugbotueferhire --- monai/inferers/inferer.py | 116 ++++++++++++++---- tests/inferers/test_diffusion_inferer.py | 84 ++++++++----- .../inferers/test_latent_diffusion_inferer.py | 73 +++++------ 3 files changed, 170 insertions(+), 103 deletions(-) diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index 3b2d6ec447..f023db490e 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -879,6 +879,60 @@ def _get_previous_sample_from_step_output(step_output: Any) -> torch.Tensor: return step_output.prev_sample raise TypeError("Unsupported scheduler.step output. Expected a tuple or an object with `prev_sample`.") + @staticmethod + def _get_scheduler_name(scheduler: Scheduler) -> str: + if hasattr(scheduler, "_get_name"): + return scheduler._get_name() + return scheduler.__class__.__name__ + + @staticmethod + def _get_scheduler_config_value(scheduler: Scheduler, name: str, default: Any = None) -> Any: + config = getattr(scheduler, "config", None) + if isinstance(config, Mapping): + if name in config: + return config[name] + elif config is not None and hasattr(config, name): + return getattr(config, name) + + if hasattr(scheduler, name): + return getattr(scheduler, name) + return default + + @staticmethod + def _get_posterior_mean( + scheduler: Scheduler, timestep: int | torch.Tensor, x_0: torch.Tensor, x_t: torch.Tensor + ) -> torch.Tensor: + alpha_t = scheduler.alphas[timestep] + alpha_prod_t = scheduler.alphas_cumprod[timestep] + alpha_prod_t_prev = scheduler.alphas_cumprod[timestep - 1] if timestep > 0 else scheduler.one + + x_0_coefficient = alpha_prod_t_prev.sqrt() * scheduler.betas[timestep] / (1 - alpha_prod_t) + x_t_coefficient = alpha_t.sqrt() * (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) + + return x_0_coefficient * x_0 + x_t_coefficient * x_t + + def _get_posterior_variance( + self, scheduler: Scheduler, timestep: int | torch.Tensor, predicted_variance: torch.Tensor | None = None + ) -> torch.Tensor: + alpha_prod_t = scheduler.alphas_cumprod[timestep] + alpha_prod_t_prev = scheduler.alphas_cumprod[timestep - 1] if timestep > 0 else scheduler.one + variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * scheduler.betas[timestep] + variance_type = self._get_scheduler_config_value(scheduler, "variance_type") + + if variance_type == "fixed_small": + variance = torch.clamp(variance, min=1e-20) + elif variance_type == "fixed_large": + variance = scheduler.betas[timestep] + elif variance_type == "learned" and predicted_variance is not None: + return predicted_variance + elif variance_type == "learned_range" and predicted_variance is not None: + min_log = variance + max_log = scheduler.betas[timestep] + frac = (predicted_variance + 1) / 2 + variance = frac * max_log + (1 - frac) * min_log + + return variance + def _scheduler_step( self, scheduler: Scheduler, @@ -1069,10 +1123,10 @@ def get_likelihood( if not scheduler: scheduler = self.scheduler - if scheduler._get_name() != "DDPMScheduler": + scheduler_name = self._get_scheduler_name(scheduler) + if scheduler_name != "DDPMScheduler": raise NotImplementedError( - f"Likelihood computation is only compatible with DDPMScheduler," - f" you are using {scheduler._get_name()}" + f"Likelihood computation is only compatible with DDPMScheduler," f" you are using {scheduler_name}" ) if mode not in ["crossattn", "concat"]: raise NotImplementedError(f"{mode} condition is not supported") @@ -1100,7 +1154,8 @@ def get_likelihood( model_output = diffusion_model(x=noisy_image, timesteps=timesteps, context=conditioning) # get the model's predicted mean, and variance if it is predicted - if model_output.shape[1] == inputs.shape[1] * 2 and scheduler.variance_type in ["learned", "learned_range"]: + variance_type = self._get_scheduler_config_value(scheduler, "variance_type") + if model_output.shape[1] == inputs.shape[1] * 2 and variance_type in ["learned", "learned_range"]: model_output, predicted_variance = torch.split(model_output, inputs.shape[1], dim=1) else: predicted_variance = None @@ -1113,15 +1168,17 @@ def get_likelihood( # 2. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf - if scheduler.prediction_type == "epsilon": + prediction_type = self._get_scheduler_config_value(scheduler, "prediction_type") + if prediction_type == "epsilon": pred_original_sample = (noisy_image - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) - elif scheduler.prediction_type == "sample": + elif prediction_type == "sample": pred_original_sample = model_output - elif scheduler.prediction_type == "v_prediction": + elif prediction_type == "v_prediction": pred_original_sample = (alpha_prod_t**0.5) * noisy_image - (beta_prod_t**0.5) * model_output # 3. Clip "predicted x_0" - if scheduler.clip_sample: - pred_original_sample = torch.clamp(pred_original_sample, -1, 1) + if self._get_scheduler_config_value(scheduler, "clip_sample"): + clip_sample_range = self._get_scheduler_config_value(scheduler, "clip_sample_range", 1.0) + pred_original_sample = torch.clamp(pred_original_sample, -clip_sample_range, clip_sample_range) # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf @@ -1133,11 +1190,15 @@ def get_likelihood( predicted_mean = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * noisy_image # get the posterior mean and variance - posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image) # type: ignore[operator] - posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance) # type: ignore[operator] + posterior_mean = self._get_posterior_mean(scheduler=scheduler, timestep=t, x_0=inputs, x_t=noisy_image) + posterior_variance = self._get_posterior_variance( + scheduler=scheduler, timestep=t, predicted_variance=predicted_variance + ) log_posterior_variance = torch.log(posterior_variance) - log_predicted_variance = torch.log(predicted_variance) if predicted_variance else log_posterior_variance + log_predicted_variance = ( + torch.log(predicted_variance) if predicted_variance is not None else log_posterior_variance + ) if t == 0: # compute -log p(x_0|x_1) @@ -1676,10 +1737,10 @@ def get_likelihood( # type: ignore[override] if not scheduler: scheduler = self.scheduler - if scheduler._get_name() != "DDPMScheduler": + scheduler_name = self._get_scheduler_name(scheduler) + if scheduler_name != "DDPMScheduler": raise NotImplementedError( - f"Likelihood computation is only compatible with DDPMScheduler," - f" you are using {scheduler._get_name()}" + f"Likelihood computation is only compatible with DDPMScheduler," f" you are using {scheduler_name}" ) if mode not in ["crossattn", "concat"]: raise NotImplementedError(f"{mode} condition is not supported") @@ -1725,7 +1786,8 @@ def get_likelihood( # type: ignore[override] mid_block_additional_residual=mid_block_res_sample, ) # get the model's predicted mean, and variance if it is predicted - if model_output.shape[1] == inputs.shape[1] * 2 and scheduler.variance_type in ["learned", "learned_range"]: + variance_type = self._get_scheduler_config_value(scheduler, "variance_type") + if model_output.shape[1] == inputs.shape[1] * 2 and variance_type in ["learned", "learned_range"]: model_output, predicted_variance = torch.split(model_output, inputs.shape[1], dim=1) else: predicted_variance = None @@ -1738,15 +1800,17 @@ def get_likelihood( # type: ignore[override] # 2. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf - if scheduler.prediction_type == "epsilon": + prediction_type = self._get_scheduler_config_value(scheduler, "prediction_type") + if prediction_type == "epsilon": pred_original_sample = (noisy_image - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) - elif scheduler.prediction_type == "sample": + elif prediction_type == "sample": pred_original_sample = model_output - elif scheduler.prediction_type == "v_prediction": + elif prediction_type == "v_prediction": pred_original_sample = (alpha_prod_t**0.5) * noisy_image - (beta_prod_t**0.5) * model_output # 3. Clip "predicted x_0" - if scheduler.clip_sample: - pred_original_sample = torch.clamp(pred_original_sample, -1, 1) + if self._get_scheduler_config_value(scheduler, "clip_sample"): + clip_sample_range = self._get_scheduler_config_value(scheduler, "clip_sample_range", 1.0) + pred_original_sample = torch.clamp(pred_original_sample, -clip_sample_range, clip_sample_range) # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf @@ -1758,11 +1822,15 @@ def get_likelihood( # type: ignore[override] predicted_mean = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * noisy_image # get the posterior mean and variance - posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image) # type: ignore[operator] - posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance) # type: ignore[operator] + posterior_mean = self._get_posterior_mean(scheduler=scheduler, timestep=t, x_0=inputs, x_t=noisy_image) + posterior_variance = self._get_posterior_variance( + scheduler=scheduler, timestep=t, predicted_variance=predicted_variance + ) log_posterior_variance = torch.log(posterior_variance) - log_predicted_variance = torch.log(predicted_variance) if predicted_variance else log_posterior_variance + log_predicted_variance = ( + torch.log(predicted_variance) if predicted_variance is not None else log_posterior_variance + ) if t == 0: # compute -log p(x_0|x_1) diff --git a/tests/inferers/test_diffusion_inferer.py b/tests/inferers/test_diffusion_inferer.py index 5099a482e0..9e1a3072dd 100644 --- a/tests/inferers/test_diffusion_inferer.py +++ b/tests/inferers/test_diffusion_inferer.py @@ -24,6 +24,7 @@ _, has_scipy = optional_import("scipy") _, has_einops = optional_import("einops") +DiffusersDDPMScheduler, has_diffusers = optional_import("diffusers", name="DDPMScheduler") TEST_CASES = [ [ @@ -55,31 +56,6 @@ ] -class DiffusersLikeSchedulerOutput: - def __init__(self, prev_sample: torch.Tensor, pred_original_sample: torch.Tensor) -> None: - self.prev_sample = prev_sample - self.pred_original_sample = pred_original_sample - - -class DiffusersStyleDDPMScheduler(DDPMScheduler): - def step( - self, - model_output: torch.Tensor, - timestep: int, - sample: torch.Tensor, - generator: torch.Generator | None = None, - return_dict: bool = True, - ): - prev_sample, pred_original_sample = super().step( - model_output=model_output, timestep=timestep, sample=sample, generator=generator - ) - if return_dict: - return DiffusersLikeSchedulerOutput( - prev_sample=prev_sample, pred_original_sample=pred_original_sample - ) - return prev_sample, pred_original_sample - - class TestDiffusionSamplingInferer(unittest.TestCase): @parameterized.expand(TEST_CASES) @skipUnless(has_einops, "Requires einops") @@ -151,22 +127,62 @@ def test_ddpm_sampler(self, model_params, input_shape): ) self.assertEqual(len(intermediates), 10) - @parameterized.expand(TEST_CASES) - @skipUnless(has_einops, "Requires einops") - def test_diffusers_style_ddpm_sampler(self, model_params, input_shape): - model = DiffusionModelUNet(**model_params) + @skipUnless(has_einops and has_diffusers, "Requires einops and diffusers") + def test_diffusers_ddpm_call(self): device = "cuda:0" if torch.cuda.is_available() else "cpu" + model = DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=[32, 64], + attention_levels=[False, True], + num_res_blocks=1, + num_head_channels=32, + ) model.to(device) model.eval() - noise = torch.randn(input_shape).to(device) - scheduler = DiffusersStyleDDPMScheduler(num_train_timesteps=1000) + scheduler = DiffusersDDPMScheduler(num_train_timesteps=1000, beta_schedule="linear", prediction_type="epsilon") + scheduler.set_timesteps(num_inference_steps=50) + inferer = DiffusionInferer(scheduler=scheduler) + + batch_size = 2 + image_size = 32 + inputs = torch.randn(batch_size, 1, image_size, image_size).to(device) + noise = torch.randn_like(inputs) + timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (batch_size,)).long().to(device) + with torch.no_grad(): + prediction = inferer(inputs=inputs, diffusion_model=model, noise=noise, timesteps=timesteps) + + self.assertEqual(prediction.shape, inputs.shape) + scheduler.set_timesteps(num_inference_steps=2) + sample = inferer.sample(input_noise=noise, diffusion_model=model, scheduler=scheduler, verbose=False) + self.assertEqual(sample.shape, inputs.shape) + + @skipUnless(has_einops and has_diffusers, "Requires einops and diffusers") + def test_diffusers_ddpm_get_likelihood(self): + device = "cuda:0" if torch.cuda.is_available() else "cpu" + model = DiffusionModelUNet( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=[8], + norm_num_groups=8, + attention_levels=[True], + num_res_blocks=1, + num_head_channels=8, + ) + model.to(device) + model.eval() + inputs = torch.randn(2, 1, 8, 8).to(device) + scheduler = DiffusersDDPMScheduler(num_train_timesteps=10, beta_schedule="linear", prediction_type="epsilon") inferer = DiffusionInferer(scheduler=scheduler) scheduler.set_timesteps(num_inference_steps=10) - sample, intermediates = inferer.sample( - input_noise=noise, diffusion_model=model, scheduler=scheduler, save_intermediates=True, intermediate_steps=1 + likelihood, intermediates = inferer.get_likelihood( + inputs=inputs, diffusion_model=model, scheduler=scheduler, save_intermediates=True ) - self.assertEqual(sample.shape, noise.shape) self.assertEqual(len(intermediates), 10) + self.assertEqual(intermediates[0].shape, inputs.shape) + self.assertEqual(likelihood.shape[0], inputs.shape[0]) @parameterized.expand(TEST_CASES) @skipUnless(has_einops, "Requires einops") diff --git a/tests/inferers/test_latent_diffusion_inferer.py b/tests/inferers/test_latent_diffusion_inferer.py index e12e2b963c..23dd594d8e 100644 --- a/tests/inferers/test_latent_diffusion_inferer.py +++ b/tests/inferers/test_latent_diffusion_inferer.py @@ -23,6 +23,7 @@ from monai.utils import optional_import _, has_einops = optional_import("einops") +DiffusersDDPMScheduler, has_diffusers = optional_import("diffusers", name="DDPMScheduler") TEST_CASES = [ [ "AutoencoderKL", @@ -313,33 +314,6 @@ ], ] -TEST_CASES_DIFFUSERS = [TEST_CASES[0]] - - -class DiffusersLikeSchedulerOutput: - def __init__(self, prev_sample: torch.Tensor, pred_original_sample: torch.Tensor) -> None: - self.prev_sample = prev_sample - self.pred_original_sample = pred_original_sample - - -class DiffusersStyleDDPMScheduler(DDPMScheduler): - def step( - self, - model_output: torch.Tensor, - timestep: int, - sample: torch.Tensor, - generator: torch.Generator | None = None, - return_dict: bool = True, - ): - prev_sample, pred_original_sample = super().step( - model_output=model_output, timestep=timestep, sample=sample, generator=generator - ) - if return_dict: - return DiffusersLikeSchedulerOutput( - prev_sample=prev_sample, pred_original_sample=pred_original_sample - ) - return prev_sample, pred_original_sample - class TestDiffusionSamplingInferer(unittest.TestCase): @parameterized.expand(TEST_CASES) @@ -441,36 +415,45 @@ def test_sample_shape( ) self.assertEqual(sample.shape, input_shape) - @parameterized.expand(TEST_CASES_DIFFUSERS) - @skipUnless(has_einops, "Requires einops") - def test_diffusers_style_ddpm_sample_shape( - self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape - ): - if ae_model_type == "AutoencoderKL": - stage_1 = AutoencoderKL(**autoencoder_params) - else: - stage_1 = VQVAE(**autoencoder_params) - - if dm_model_type == "SPADEDiffusionModelUNet": - stage_2 = SPADEDiffusionModelUNet(**stage_2_params) - else: - stage_2 = DiffusionModelUNet(**stage_2_params) - + @skipUnless(has_einops and has_diffusers, "Requires einops and diffusers") + def test_diffusers_ddpm_sample_shape(self): device = "cuda:0" if torch.cuda.is_available() else "cpu" + stage_1 = AutoencoderKL( + spatial_dims=2, + in_channels=1, + out_channels=1, + channels=(4, 4), + latent_channels=3, + attention_levels=[False, False], + num_res_blocks=1, + with_encoder_nonlocal_attn=False, + with_decoder_nonlocal_attn=False, + norm_num_groups=4, + ) + stage_2 = DiffusionModelUNet( + spatial_dims=2, + in_channels=3, + out_channels=3, + channels=[4, 4], + norm_num_groups=4, + attention_levels=[False, False], + num_res_blocks=1, + num_head_channels=4, + ) stage_1.to(device) stage_2.to(device) stage_1.eval() stage_2.eval() - noise = torch.randn(latent_shape).to(device) - scheduler = DiffusersStyleDDPMScheduler(num_train_timesteps=1000) + noise = torch.randn(1, 3, 4, 4).to(device) + scheduler = DiffusersDDPMScheduler(num_train_timesteps=10, beta_schedule="linear", prediction_type="epsilon") inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0) scheduler.set_timesteps(num_inference_steps=10) sample = inferer.sample( input_noise=noise, autoencoder_model=stage_1, diffusion_model=stage_2, scheduler=scheduler ) - self.assertEqual(sample.shape, input_shape) + self.assertEqual(sample.shape, (1, 1, 8, 8)) @parameterized.expand(TEST_CASES) @skipUnless(has_einops, "Requires einops") From 414e42e24b02f886c026597f0ff8f5ff5b7f9420 Mon Sep 17 00:00:00 2001 From: ugbotueferhire Date: Sun, 31 May 2026 14:01:57 +0100 Subject: [PATCH 3/3] DCO Remediation Commit for ugbotueferhire I, ugbotueferhire , hereby add my Signed-off-by to this commit: 8f95fb140cfcf48b95ab5b53e8e17485754779e3 Signed-off-by: ugbotueferhire