From dc8b08a17624d01c28a9922c800e7350cf489a3d Mon Sep 17 00:00:00 2001 From: Rohit Jena Date: Mon, 26 Aug 2024 10:52:38 -0700 Subject: [PATCH 1/8] Added annealing inference python and bash scripts Signed-off-by: Rohit Jena --- examples/mm/stable_diffusion/anneal_sd.py | 213 +++++++++++++ examples/mm/stable_diffusion/anneal_sdxl.py | 279 ++++++++++++++++++ .../mm/stable_diffusion/launch_annealing.sh | 64 ++++ .../stable_diffusion/launch_annealing_xl.sh | 96 ++++++ .../megatron_sd_draftp_model.py | 94 ++++++ .../megatron_sdxl_draftp_model.py | 55 ++++ 6 files changed, 801 insertions(+) create mode 100644 examples/mm/stable_diffusion/anneal_sd.py create mode 100644 examples/mm/stable_diffusion/anneal_sdxl.py create mode 100644 examples/mm/stable_diffusion/launch_annealing.sh create mode 100644 examples/mm/stable_diffusion/launch_annealing_xl.sh diff --git a/examples/mm/stable_diffusion/anneal_sd.py b/examples/mm/stable_diffusion/anneal_sd.py new file mode 100644 index 000000000..f6274dcfa --- /dev/null +++ b/examples/mm/stable_diffusion/anneal_sd.py @@ -0,0 +1,213 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +import torch.distributed +import torch.multiprocessing as mp +from megatron.core import parallel_state +from megatron.core.utils import divide +from omegaconf.omegaconf import OmegaConf, open_dict +from copy import deepcopy +import os +from functools import partial +from torch import nn +import numpy as np +from megatron.core.tensor_parallel.random import get_cuda_rng_tracker, get_data_parallel_rng_tracker_name +from PIL import Image +from packaging.version import Version + +from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronStableDiffusionTrainerBuilder +from nemo.collections.nlp.parts.peft_config import PEFT_CONFIG_MAP +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager +from nemo_aligner.algorithms.supervised import SupervisedTrainer +from nemo_aligner.data.mm import text_webdataset +from nemo_aligner.data.nlp.builders import build_dataloader +from nemo_aligner.models.mm.stable_diffusion.image_text_rms import get_reward_model +from nemo_aligner.models.mm.stable_diffusion.megatron_sd_draftp_model import MegatronSDDRaFTPModel +from nemo_aligner.utils.distributed import Timer +from nemo_aligner.utils.train_script_utils import ( + CustomLoggerWrapper, + add_custom_checkpoint_callback, + extract_optimizer_scheduler_from_ptl_model, + init_distributed, + init_peft, + init_using_ptl, + retrieve_custom_trainer_state_dict, + temp_pop_from_config, +) + +mp.set_start_method("spawn", force=True) + + +def resolve_and_create_trainer(cfg, pop_trainer_key): + """resolve the cfg, remove the key before constructing the PTL trainer + and then restore it after + """ + OmegaConf.resolve(cfg) + with temp_pop_from_config(cfg.trainer, pop_trainer_key): + return MegatronStableDiffusionTrainerBuilder(cfg).create_trainer() + + +@hydra_runner(config_path="conf", config_name="draftp_sd") +def main(cfg) -> None: + + logging.info("\n\n************** Experiment configuration 
***********") + logging.info(f"\n{OmegaConf.to_yaml(cfg)}") + + # set cuda device for each process + local_rank = int(os.environ.get('LOCAL_RANK', 0)) + torch.cuda.set_device(local_rank) + + cfg.exp_manager.create_wandb_logger = False + + if Version(torch.__version__) >= Version("1.12"): + torch.backends.cuda.matmul.allow_tf32 = True + cfg.model.data.train.dataset_path = [cfg.model.data.webdataset.local_root_path for _ in range(cfg.trainer.devices)] + cfg.model.data.validation.dataset_path = [ + cfg.model.data.webdataset.local_root_path for _ in range(cfg.trainer.devices) + ] + + trainer = resolve_and_create_trainer(cfg, "draftp_sd") + exp_manager(trainer, cfg.exp_manager) + logger = CustomLoggerWrapper(trainer.loggers) + # Instatiating the model here + ptl_model = MegatronSDDRaFTPModel(cfg.model, trainer).to(torch.cuda.current_device()) + init_peft(ptl_model, cfg.model) + + trainer_restore_path = trainer.ckpt_path + + if trainer_restore_path is not None: + custom_trainer_state_dict = retrieve_custom_trainer_state_dict(trainer) + consumed_samples = custom_trainer_state_dict["consumed_samples"] + else: + custom_trainer_state_dict = None + consumed_samples = 0 + + init_distributed(trainer, ptl_model, cfg.model.get("transformer_engine", False)) + + train_ds, validation_ds = text_webdataset.build_train_valid_datasets( + cfg.model.data, consumed_samples=consumed_samples + ) + # train_ds = [d["captions"] for d in list(train_ds)] + validation_ds = [d["captions"] for d in list(validation_ds)] + + val_dataloader = build_dataloader( + cfg, + dataset=validation_ds, + consumed_samples=consumed_samples, + mbs=cfg.model.micro_batch_size, + gbs=cfg.model.global_batch_size, + load_gbs=True, + ) + + init_using_ptl(trainer, ptl_model, val_dataloader, validation_ds) + + optimizer, scheduler = extract_optimizer_scheduler_from_ptl_model(ptl_model) + + ckpt_callback = add_custom_checkpoint_callback(trainer, ptl_model) + + logger.log_hyperparams(OmegaConf.to_container(cfg)) + + 
reward_model = get_reward_model(cfg.rm, mbs=cfg.model.micro_batch_size, gbs=cfg.model.global_batch_size) + ptl_model.reward_model = reward_model + + ckpt_callback = add_custom_checkpoint_callback(trainer, ptl_model) + timer = Timer(cfg.exp_manager.get("max_time_per_run", "0:12:00:00")) + + draft_p_trainer = SupervisedTrainer( + cfg=cfg.trainer.draftp_sd, + model=ptl_model, + optimizer=optimizer, + scheduler=scheduler, + train_dataloader=val_dataloader, + val_dataloader=val_dataloader, + test_dataloader=[], + logger=logger, + ckpt_callback=ckpt_callback, + run_timer=timer, + ) + + if custom_trainer_state_dict is not None: + draft_p_trainer.load_state_dict(custom_trainer_state_dict) + + # Run annealed guidance + if cfg.get("prompt") is not None: + logging.info(f"Override val dataset with custom prompt: {cfg.prompt}") + val_dataloader = [[cfg.prompt]] + + wt_types = cfg.get("weight_type", None) + if wt_types is None: + wt_types = ['base', 'draft', 'linear', 'power_2', 'power_4', 'step_0.6'] + else: + wt_types = wt_types.split(",") if isinstance(wt_types, str) else wt_types + logging.info(f"Running on types: {wt_types}") + + # run for all weight types + for wt_type in wt_types: + global_idx = 0 + if wt_type is None or wt_type == 'base': + # dummy function that assigns a value of 0 all the time + logging.info("using the base model") + wt_draft = lambda sigma, sigma_next, i, total: 0 + else: + if wt_type == 'linear': + wt_draft = lambda sigma, sigma_next, i, total: i*1.0/total + elif wt_type == 'draft': + wt_draft = lambda sigma, sigma_next, i, total: 1 + elif wt_type.startswith('power'): # its of the form power_{power} + pow = float(wt_type.split("_")[1]) + wt_draft = lambda sigma, sigma_next, i, total: (i*1.0/total)**pow + elif wt_type.startswith("step"): # use a step function (step_{p}) + frac = float(wt_type.split("_")[1]) + wt_draft = lambda sigma, sigma_next, i, total: float((i*1.0/total) >= frac) + else: + raise ValueError(f"invalid weighing type: {wt_type}") + 
logging.info(f"using weighing type for annealed outputs: {wt_type}.") + + # initialize generator + gen = torch.Generator(device='cpu') + gen.manual_seed((1243 + 1247837 * local_rank)%(int(2**32 - 1))) + os.makedirs(f"./annealed_outputs_sd_{wt_type}/", exist_ok=True) + + for batch in val_dataloader: + batch_size = len(batch) + with get_cuda_rng_tracker().fork(get_data_parallel_rng_tracker_name()): + latents = torch.randn( + [ + batch_size, + ptl_model.in_channels, + ptl_model.height // ptl_model.downsampling_factor, + ptl_model.width // ptl_model.downsampling_factor, + ], + generator=gen, + ).to(torch.cuda.current_device()) + images = ptl_model.annealed_guidance(batch, latents, weighing_fn=wt_draft) + images = images.permute(0, 2, 3, 1).detach().float().cpu().numpy().astype(np.uint8) # outputs are already scaled from [0, 255] + # save to pil + for i in range(images.shape[0]): + i = i + global_idx + img_path = f"annealed_outputs_sd_{wt_type}/img_{i:05d}_{local_rank:02d}.png" + prompt_path = f"annealed_outputs_sd_{wt_type}/prompt_{i:05d}_{local_rank:02d}.txt" + Image.fromarray(images[i]).save(img_path) + with open(prompt_path, "w") as fi: + fi.write(batch[i]) + # increment global index + global_idx += batch_size + logging.info("Saved all images.") + + +if __name__ == "__main__": + main() diff --git a/examples/mm/stable_diffusion/anneal_sdxl.py b/examples/mm/stable_diffusion/anneal_sdxl.py new file mode 100644 index 000000000..3b64059a0 --- /dev/null +++ b/examples/mm/stable_diffusion/anneal_sdxl.py @@ -0,0 +1,279 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.distributed +import torch.multiprocessing as mp +from megatron.core import parallel_state +from megatron.core.utils import divide +from omegaconf.omegaconf import OmegaConf, open_dict +from copy import deepcopy +import os +from functools import partial +from torch import nn +import numpy as np +from megatron.core.tensor_parallel.random import get_cuda_rng_tracker, get_data_parallel_rng_tracker_name +from PIL import Image + +# from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronStableDiffusionTrainerBuilder +from nemo.collections.nlp.parts.peft_config import PEFT_CONFIG_MAP +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager +from nemo_aligner.algorithms.supervised import SupervisedTrainer +from nemo_aligner.data.mm import text_webdataset +from nemo_aligner.data.nlp.builders import build_dataloader +from nemo_aligner.models.mm.stable_diffusion.image_text_rms import get_reward_model +from nemo_aligner.models.mm.stable_diffusion.megatron_sdxl_draftp_model import MegatronSDXLDRaFTPModel +from nemo_aligner.utils.distributed import Timer +from packaging.version import Version +from nemo_aligner.utils.train_script_utils import ( + CustomLoggerWrapper, + add_custom_checkpoint_callback, + extract_optimizer_scheduler_from_ptl_model, + init_distributed, + init_peft, + init_using_ptl, + retrieve_custom_trainer_state_dict, + temp_pop_from_config, +) +from nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.ddpm import ( + 
LatentDiffusion, + MegatronLatentDiffusion, +) +from nemo.collections.multimodal.models.text_to_image.stable_diffusion.diffusion_engine import MegatronDiffusionEngine, DiffusionEngine +from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder +from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPFSDPStrategy +from nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.openaimodel import UNetModel, ResBlock, SpatialTransformer, TimestepEmbedSequential +from nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.autoencoder import AutoencoderKL, AutoencoderKLInferenceWrapper +from nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.model import Encoder, Decoder, ResnetBlock, AttnBlock +from nemo_aligner.models.mm.stable_diffusion.image_text_rms import MegatronCLIPRewardModel +from nemo.collections.multimodal.modules.stable_diffusion.encoders.modules import FrozenOpenCLIPEmbedder, FrozenOpenCLIPEmbedder2, FrozenCLIPEmbedder +from nemo.collections.nlp.modules.common.megatron.adapters.parallel_adapters import ParallelLinearAdapter +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + +# checkpointing +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (checkpoint_wrapper, CheckpointImpl, apply_activation_checkpointing) + +mp.set_start_method("spawn", force=True) + +class MegatronStableDiffusionTrainerBuilder(MegatronTrainerBuilder): + """Builder for SD model Trainer with overrides.""" + def _training_strategy(self) -> NLPDDPStrategy: + """ + Returns a DDP or a FSDP strategy passed to Trainer.strategy. Copied from `sd_xl_train.py` + """ + if self.cfg.model.get('fsdp', False): + logging.info("FSDP.") + assert ( + not self.cfg.model.optim.get('name') == 'distributed_fused_adam' + ), 'Distributed optimizer cannot be used with FSDP.' 
+ if self.cfg.model.get('megatron_amp_O2', False): + logging.info('Torch FSDP is not compatible with O2 precision recipe. Setting O2 `False`.') + self.cfg.model.megatron_amp_O2 = False + + # Check if its a full-finetuning or PEFT + return NLPFSDPStrategy( + limit_all_gathers=self.cfg.model.get('fsdp_limit_all_gathers', True), + sharding_strategy=self.cfg.model.get('fsdp_sharding_strategy', 'full'), + cpu_offload=self.cfg.model.get('fsdp_cpu_offload', False), # offload on is not supported + grad_reduce_dtype=self.cfg.model.get('fsdp_grad_reduce_dtype', 32), + precision=self.cfg.trainer.precision, + ## nn Sequential is supposed to capture the `t_embed`, `label_emb`, `out` layers in the unet + extra_fsdp_wrap_module={UNetModel,TimestepEmbedSequential,Decoder,ResnetBlock,AttnBlock,nn.Sequential,\ + MegatronCLIPRewardModel,FrozenOpenCLIPEmbedder,FrozenOpenCLIPEmbedder2,FrozenCLIPEmbedder,\ + ParallelLinearAdapter}, + # extra_fsdp_wrap_module={UNetModel,TimestepEmbedSequential,Decoder,ResnetBlock,AttnBlock,SpatialTransformer,ResBlock,\ + use_orig_params=False, #self.cfg.model.inductor, + set_buffer_dtype=self.cfg.get('fsdp_set_buffer_dtype', None), + ) + + return NLPDDPStrategy( + no_ddp_communication_hook=(not self.cfg.model.get('ddp_overlap')), + gradient_as_bucket_view=self.cfg.model.gradient_as_bucket_view, + find_unused_parameters=False, + ) + + +def resolve_and_create_trainer(cfg, pop_trainer_key): + """resolve the cfg, remove the key before constructing the PTL trainer + and then restore it after + """ + OmegaConf.resolve(cfg) + with temp_pop_from_config(cfg.trainer, pop_trainer_key): + return MegatronStableDiffusionTrainerBuilder(cfg).create_trainer() + + +@hydra_runner(config_path="conf", config_name="draftp_sdxl") +def main(cfg) -> None: + + logging.info("\n\n************** Experiment configuration ***********") + logging.info(f"\n{OmegaConf.to_yaml(cfg)}") + + # set cuda device for each process + local_rank = int(os.environ.get('LOCAL_RANK', 0)) + 
torch.cuda.set_device(local_rank) + + # turn off wandb logging + cfg.exp_manager.create_wandb_logger = False + + if Version(torch.__version__) >= Version("1.12"): + torch.backends.cuda.matmul.allow_tf32 = True + cfg.model.data.train.dataset_path = [ + cfg.model.data.webdataset.local_root_path for _ in range(cfg.trainer.devices * cfg.trainer.num_nodes) + ] + cfg.model.data.validation.dataset_path = [ + cfg.model.data.webdataset.local_root_path for _ in range(cfg.trainer.devices * cfg.trainer.num_nodes) + ] + + trainer = resolve_and_create_trainer(cfg, "draftp_sd") + exp_manager(trainer, cfg.exp_manager) + logger = CustomLoggerWrapper(trainer.loggers) + # Instatiating the model here + ptl_model = MegatronSDXLDRaFTPModel(cfg.model, trainer).to(torch.cuda.current_device()) + init_peft(ptl_model, cfg.model) # init peft + + trainer_restore_path = trainer.ckpt_path + + if trainer_restore_path is not None: + custom_trainer_state_dict = retrieve_custom_trainer_state_dict(trainer) + consumed_samples = custom_trainer_state_dict["consumed_samples"] + else: + custom_trainer_state_dict = None + consumed_samples = 0 + + init_distributed(trainer, ptl_model, cfg.model.get("transformer_engine", False)) + + # use the validation ds if needed + train_ds, validation_ds = text_webdataset.build_train_valid_datasets( + cfg.model.data, consumed_samples=consumed_samples + ) + validation_ds = [d["captions"] for d in list(validation_ds)] + + val_dataloader = build_dataloader( + cfg, + dataset=validation_ds, + consumed_samples=consumed_samples, + mbs=cfg.model.micro_batch_size, + gbs=cfg.model.global_batch_size, + load_gbs=True, + ) + init_using_ptl(trainer, ptl_model, val_dataloader, validation_ds) + + if cfg.model.get('activation_checkpointing', False): + # call activation checkpointing here + # checkpoint wrapper + logging.info("Applying activation checkpointing on UNet and Decoder.") + non_reentrant_wrapper = partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT) + def 
checkpoint_check_fn(module): + return isinstance(module, (Decoder, UNetModel, MegatronCLIPRewardModel)) + apply_activation_checkpointing(ptl_model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=checkpoint_check_fn) + + optimizer, scheduler = extract_optimizer_scheduler_from_ptl_model(ptl_model) + + ckpt_callback = add_custom_checkpoint_callback(trainer, ptl_model) + + logger.log_hyperparams(OmegaConf.to_container(cfg)) + + torch.distributed.barrier() + + ckpt_callback = add_custom_checkpoint_callback(trainer, ptl_model) + timer = Timer(cfg.exp_manager.get("max_time_per_run", "0:03:55:00")) # save a model just before 4 hours + + draft_p_trainer = SupervisedTrainer( + cfg=cfg.trainer.draftp_sd, + model=ptl_model, + optimizer=optimizer, + scheduler=scheduler, + train_dataloader=val_dataloader, + val_dataloader=val_dataloader, + test_dataloader=[], + logger=logger, + ckpt_callback=ckpt_callback, + run_timer=timer, + run_init_validation=True, + ) + + if custom_trainer_state_dict is not None: + draft_p_trainer.load_state_dict(custom_trainer_state_dict) + + torch.cuda.empty_cache() + + if cfg.get("prompt") is not None: + logging.info(f"Override val dataset with custom prompt: {cfg.prompt}") + val_dataloader = [[cfg.prompt]] + + wt_types = cfg.get("weight_type", None) + if wt_types is None: + wt_types = ['base', 'draft', 'linear', 'power_2', 'power_4', 'step_0.6'] + else: + wt_types = wt_types.split(",") if isinstance(wt_types, str) else wt_types + logging.info(f"Running on types: {wt_types}") + + # run for all weight types + for wt_type in wt_types: + global_idx = 0 + if wt_type is None or wt_type == 'base': + # dummy function that assigns a value of 0 all the time + logging.info("using the base model") + wt_draft = lambda sigma, sigma_next, i, total: 0 + else: + if wt_type == 'linear': + wt_draft = lambda sigma, sigma_next, i, total: i*1.0/total + elif wt_type == 'draft': + wt_draft = lambda sigma, sigma_next, i, total: 1 + elif wt_type.startswith('power'): # its 
of the form power_{power} + pow = float(wt_type.split("_")[1]) + wt_draft = lambda sigma, sigma_next, i, total: (i*1.0/total)**pow + elif wt_type.startswith("step"): # use a step function (step_{p}) + frac = float(wt_type.split("_")[1]) + wt_draft = lambda sigma, sigma_next, i, total: float((i*1.0/total) >= frac) + else: + raise ValueError(f"invalid weighing type: {wt_type}") + logging.info(f"using weighing type for annealed outputs: {wt_type}.") + + # initialize generator + gen = torch.Generator(device='cpu') + gen.manual_seed((1243 + 1247837 * local_rank)%(int(2**32 - 1))) + os.makedirs(f"./annealed_outputs_sdxl_{wt_type}/", exist_ok=True) + + for batch in val_dataloader: + batch_size = len(batch) + with get_cuda_rng_tracker().fork(get_data_parallel_rng_tracker_name()): + latents = torch.randn( + [ + batch_size, + ptl_model.in_channels, + ptl_model.height // ptl_model.downsampling_factor, + ptl_model.width // ptl_model.downsampling_factor, + ], + generator=gen, + ).to(torch.cuda.current_device()) + images = ptl_model.annealed_guidance(batch, latents, weighing_fn=wt_draft) + images = images.permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) # outputs are already scaled from [0, 255] + # save to pil + for i in range(images.shape[0]): + i = i + global_idx + img_path = f"annealed_outputs_sdxl_{wt_type}/img_{i:05d}_{local_rank:02d}.png" + prompt_path = f"annealed_outputs_sdxl_{wt_type}/prompt_{i:05d}_{local_rank:02d}.txt" + Image.fromarray(images[i]).save(img_path) + with open(prompt_path, "w") as fi: + fi.write(batch[i]) + # increment global index + global_idx += batch_size + logging.info("Saved all images.") + + +if __name__ == "__main__": + main() diff --git a/examples/mm/stable_diffusion/launch_annealing.sh b/examples/mm/stable_diffusion/launch_annealing.sh new file mode 100644 index 000000000..7d3f14fe4 --- /dev/null +++ b/examples/mm/stable_diffusion/launch_annealing.sh @@ -0,0 +1,64 @@ +#!/bin/bash +export 
PYTHONPATH=/opt/NeMo:/opt/nemo-aligner:$PYTHONPATH + +LR=${LR:=0.00025} +INF_STEPS=${INF_STEPS:=25} +KL_COEF=${KL_COEF:=0.1} +ETA=${ETA:=0.0} +DATASET=${DATASET:="pickapic50k.tar"} +MICRO_BS=${MICRO_BS:=2} +GRAD_ACCUMULATION=${GRAD_ACCUMULATION:=4} +PEFT=${PEFT:="sdlora"} +NUM_DEVICES=8 +GLOBAL_BATCH_SIZE=$((MICRO_BS*NUM_DEVICES*GRAD_ACCUMULATION)) + +WANDB_NAME=SD_DRaFT+${JOBNAME}_lr_${LR}_data_${DATASET}_kl_${KL_COEF}_bs_${GLOBAL_BATCH_SIZE}_infstep_${INF_STEPS}_eta_${ETA}_peft_${PEFT} +WEBDATASET_PATH=/path/to/${DATASET} + +CONFIG_PATH="/opt/nemo-aligner/examples/mm/stable_diffusion/conf" +CONFIG_NAME="draftp_sd" +UNET_CKPT="/path/to/unet.ckpt" +VAE_CKPT="/path/to/vae.ckpt" +RM_CKPT="/path/to/rewardmodel.nemo" +DIR_SAVE_CKPT_PATH=/opt/nemo-aligner/sd_draft_runs/draftp_saved_ckpts_${JOBNAME} + +# change this as an end-user +PROMPT=${PROMPT:-"Bananas growing on an apple tree"} + +mkdir -p ${DIR_SAVE_CKPT_PATH} + +EVAL_SCRIPT=${EVAL_SCRIPT:-"anneal_sd.py"} +set -x +DEVICE="0,1,2,3,4,5,6,7" +echo "Running DRaFT on ${DEVICE}" +export HYDRA_FULL_ERROR=1 \ +&& MASTER_PORT=15003 CUDA_VISIBLE_DEVICES="${DEVICE}" torchrun --nproc_per_node=8 /opt/nemo-aligner/examples/mm/stable_diffusion/${EVAL_SCRIPT} \ + --config-path=${CONFIG_PATH} \ + --config-name=${CONFIG_NAME} \ + model.optim.lr=${LR} \ + model.optim.weight_decay=0.005 \ + model.optim.sched.warmup_steps=0 \ + model.infer.inference_steps=${INF_STEPS} \ + model.infer.eta=0.0 \ + model.kl_coeff=${KL_COEF} \ + model.truncation_steps=1 \ + trainer.draftp_sd.max_epochs=1 \ + trainer.draftp_sd.max_steps=4000 \ + trainer.draftp_sd.save_interval=100 \ + model.unet_config.from_pretrained=${UNET_CKPT} \ + model.first_stage_config.from_pretrained=${VAE_CKPT} \ + model.micro_batch_size=${MICRO_BS} \ + model.global_batch_size=${GLOBAL_BATCH_SIZE} \ + model.peft.peft_scheme=${PEFT} \ + model.data.webdataset.local_root_path=$WEBDATASET_PATH \ + rm.model.restore_from_path=${RM_CKPT} \ + +prompt="${PROMPT}" \ + 
trainer.draftp_sd.val_check_interval=20 \ + trainer.draftp_sd.gradient_clip_val=10.0 \ + trainer.devices=${NUM_DEVICES} \ + rm.trainer.devices=${NUM_DEVICES} \ + exp_manager.create_wandb_logger=True \ + exp_manager.wandb_logger_kwargs.name=${WANDB_NAME} \ + exp_manager.resume_if_exists=True \ + exp_manager.explicit_log_dir=${DIR_SAVE_CKPT_PATH} \ + exp_manager.wandb_logger_kwargs.project=${PROJECT} $ADDITIONAL_KWARGS diff --git a/examples/mm/stable_diffusion/launch_annealing_xl.sh b/examples/mm/stable_diffusion/launch_annealing_xl.sh new file mode 100644 index 000000000..2ac6590e3 --- /dev/null +++ b/examples/mm/stable_diffusion/launch_annealing_xl.sh @@ -0,0 +1,96 @@ +#!/bin/bash + +export PYTHONPATH=/opt/NeMo:/opt/nemo-aligner:$PYTHONPATH + +# setup multinodes +if [ ! -z "$NNODES" ]; then + NUMNODES=$NNODES; + NNODES="--nnodes $NNODES" +else + NUMNODES=1; + NNODES="" +fi + +echo "Setting nodes to $NNODES" +LR=${LR:=0.00025} +INF_STEPS=${INF_STEPS:=25} +KL_COEF=${KL_COEF:=0.1} +ETA=${ETA:=0.0} +DATASET=${DATASET:="pickapic50k.tar"} +MICRO_BS=${MICRO_BS:=1} +GRAD_ACCUMULATION=${GRAD_ACCUMULATION:=4} +PEFT=${PEFT:="sdlora"} +NUM_DEVICES=${NUM_DEVICES:=8} +GLOBAL_BATCH_SIZE=$((MICRO_BS*NUM_DEVICES*GRAD_ACCUMULATION*NUMNODES)) +LOG_WANDB=${LOG_WANDB:="False"} +SLEEP=${SLEEP:=0} +if [ -z "${JOBNAME}" ]; then + echo "JOBNAME not specified, exiting" + exit +fi + +echo "additional kwargs: ${ADDITIONAL_KWARGS}" + +WANDB_NAME=SDXL_DRaFT+${JOBNAME}_lr_${LR}_data_${DATASET}_kl_${KL_COEF}_bs_${GLOBAL_BATCH_SIZE}_infstep_${INF_STEPS}_eta_${ETA}_peft_${PEFT} +WEBDATASET_PATH=/path/to/dataset + +CONFIG_PATH="/opt/nemo-aligner/examples/mm/stable_diffusion/conf" +CONFIG_NAME=${CONFIG_NAME:="draftp_sdxl"} +UNET_CKPT="/path/to/unet.ckpt" +VAE_CKPT="/path/to/vae.ckpt" +RM_CKPT="/path/to/rewardmodel.nemo" +PROMPT=${PROMPT:="Bananas growing on an apple tree"} + +DIR_SAVE_CKPT_PATH=/opt/nemo-aligner/sdxl_draft_runs/draftp_xl_saved_ckpts_${JOBNAME} +if [ ! 
-z "${ACT_CKPT}" ]; then + ACT_CKPT="model.activation_checkpointing=$ACT_CKPT " + echo $ACT_CKPT +fi + +mkdir -p ${DIR_SAVE_CKPT_PATH} + +## Setup multinode parameters +if [ ! -z "${RDZV_ID}" ]; then + DISTRIBUTED_PARAMS="--rdzv_id $RDZV_ID --rdzv_backend c10d --rdzv_endpoint $head_node_ip:30030" +else + DISTRIBUTED_PARAMS="--master_port=30030" +fi +echo "Setting distributed params to $DISTRIBUTED_PARAMS" + +EVAL_SCRIPT=${EVAL_SCRIPT:-"anneal_sdxl.py"} +export DEVICE="0,1,2,3,4,5,6,7" && echo "Running DRaFT+ on ${DEVICE}" && export HYDRA_FULL_ERROR=1 +set -x +CUDA_VISIBLE_DEVICES="${DEVICE}" torchrun --nproc_per_node=$NUM_DEVICES $NNODES $DISTRIBUTED_PARAMS /opt/nemo-aligner/examples/mm/stable_diffusion/${EVAL_SCRIPT} \ + --config-path=${CONFIG_PATH} \ + --config-name=${CONFIG_NAME} \ + model.optim.lr=${LR} \ + model.optim.weight_decay=0.0005 \ + model.optim.sched.warmup_steps=0 \ + model.sampling.base.steps=${INF_STEPS} \ + model.kl_coeff=${KL_COEF} \ + model.truncation_steps=1 \ + trainer.draftp_sd.max_epochs=5 \ + trainer.draftp_sd.max_steps=10000 \ + trainer.draftp_sd.save_interval=200 \ + trainer.draftp_sd.val_check_interval=20 \ + trainer.draftp_sd.gradient_clip_val=10.0 \ + model.micro_batch_size=${MICRO_BS} \ + model.global_batch_size=${GLOBAL_BATCH_SIZE} \ + model.peft.peft_scheme=${PEFT} \ + model.data.webdataset.local_root_path=$WEBDATASET_PATH \ + rm.model.restore_from_path=${RM_CKPT} \ + trainer.devices=${NUM_DEVICES} \ + trainer.num_nodes=${NUMNODES} \ + rm.trainer.devices=${NUM_DEVICES} \ + rm.trainer.num_nodes=${NUMNODES} \ + +prompt="${PROMPT}" \ + exp_manager.create_wandb_logger=${LOG_WANDB} \ + model.first_stage_config.from_pretrained=${VAE_CKPT} \ + model.first_stage_config.from_NeMo=True \ + model.unet_config.from_pretrained=${UNET_CKPT} \ + model.unet_config.from_NeMo=True \ + $ACT_CKPT \ + exp_manager.wandb_logger_kwargs.name=${WANDB_NAME} \ + exp_manager.resume_if_exists=True \ + exp_manager.explicit_log_dir=${DIR_SAVE_CKPT_PATH} \ + 
exp_manager.wandb_logger_kwargs.project=${PROJECT} ${ADDITIONAL_KWARGS} diff --git a/nemo_aligner/models/mm/stable_diffusion/megatron_sd_draftp_model.py b/nemo_aligner/models/mm/stable_diffusion/megatron_sd_draftp_model.py index 746b95606..aa447586e 100644 --- a/nemo_aligner/models/mm/stable_diffusion/megatron_sd_draftp_model.py +++ b/nemo_aligner/models/mm/stable_diffusion/megatron_sd_draftp_model.py @@ -205,6 +205,100 @@ def log_visualization(self, prompts): return vae_decoder_output_draft_p, images, captions + @torch.no_grad() + def annealed_guidance(self, batch, x_T, weighing_fn=None): + ''' this function tries to perform sampling with a modified score function at each step which is an average + of the base model and the trained model ''' + if weighing_fn is None: + weighing_fn = lambda sigma1, sigma2, i, total: i*1.0/total + + with torch.cuda.amp.autocast( + enabled=self.autocast_dtype in (torch.half, torch.bfloat16), dtype=self.autocast_dtype, + ): + batch_size = len(batch) + prev_img_draft_p = x_T + + device_draft_p = self.model.betas.device + + # init sampler and make schedule + sampler_draft_p = sampling_utils.initialize_sampler(self.model, self.sampler_type.upper()) + sampler_init = sampling_utils.initialize_sampler(self.init_model, self.sampler_type.upper()) + sampler_draft_p.make_schedule(ddim_num_steps=self.inference_steps, ddim_eta=self.eta, verbose=False) + sampler_init.make_schedule(ddim_num_steps=self.inference_steps, ddim_eta=self.eta, verbose=False) + + cond, u_cond = sampling_utils.encode_prompt( + self.model.cond_stage_model, batch, self.unconditional_guidance_scale + ) + + timesteps = sampler_draft_p.ddim_timesteps + time_range = np.flip(timesteps) + total_steps = timesteps.shape[0] + + iterator = tqdm(time_range, desc=f"{sampler_draft_p.sampler.name} Sampler", total=total_steps) + + list_eps_draft_p = [] + list_eps_init = [] + truncation_steps = self.cfg.truncation_steps + + denoise_step_kwargs = { + "unconditional_guidance_scale": 
self.unconditional_guidance_scale, + "unconditional_conditioning": u_cond, + } + for i, step in enumerate(iterator): + + denoise_step_args = [total_steps, i, batch_size, device_draft_p, step, cond] + + # run ddim step for FT model + img_draft_p, pred_x0_draft_p, eps_t_draft_p = sampler_draft_p.single_ddim_denoise_step( + prev_img_draft_p.clone(), *denoise_step_args, **denoise_step_kwargs + ) + # run ddim step for base model + img_init, pred_x0_init, eps_t_init = sampler_init.single_ddim_denoise_step( + prev_img_draft_p.clone(), *denoise_step_args, **denoise_step_kwargs + ) + # sigmas_i = sampler_draft_p.ddim_sigmas[i] + # get weighing scheme + w_draft = float(weighing_fn(None, None, i, total_steps)) + w_base = 1 - w_draft + # combine weights + eps = w_base * eps_t_init + w_draft * eps_t_draft_p + # use this to get new image + index = total_steps - i - 1 + ts = torch.full((batch_size,), step, device=device_draft_p, dtype=torch.long) + # get new image + img_new_p, pred_x0_new_p = sampler_draft_p._get_x_prev_and_pred_x0( + False, + batch_size, + index, + device_draft_p, + prev_img_draft_p.clone(), + ts, + None, # model output, we shouldnt need this + eps, + False, + False, + 1.0, + 0.0, + ) + prev_img_draft_p = img_new_p + + last_states = [pred_x0_draft_p] + # stack + trajectories_predx0 = ( + torch.stack(last_states, dim=0).transpose(0, 1).contiguous().view(-1, *last_states[0].shape[1:]) + ) # B1CHW -> BCHW + + vae_decoder_output = [] + for i in range(0, batch_size, self.vae_batch_size): + image = self.model.differentiable_decode_first_stage(trajectories_predx0[i : i + self.vae_batch_size]) + vae_decoder_output.append(image) + + vae_decoder_output = torch.cat(vae_decoder_output, dim=0) + vae_decoder_output = torch.clip((vae_decoder_output + 1) / 2, 0, 1) * 255.0 + + return vae_decoder_output + + def generate( self, batch, x_T, ): diff --git a/nemo_aligner/models/mm/stable_diffusion/megatron_sdxl_draftp_model.py 
b/nemo_aligner/models/mm/stable_diffusion/megatron_sdxl_draftp_model.py index 9fc881a67..6584af318 100644 --- a/nemo_aligner/models/mm/stable_diffusion/megatron_sdxl_draftp_model.py +++ b/nemo_aligner/models/mm/stable_diffusion/megatron_sdxl_draftp_model.py @@ -42,6 +42,7 @@ get_unique_embedder_keys_from_conditioner, ) from nemo.collections.multimodal.parts.stable_diffusion.sdxl_pipeline import get_sampler_config +from nemo.collections.multimodal.parts.stable_diffusion.utils import append_dims, default, instantiate_from_config from nemo.collections.nlp.modules.common.megatron.utils import average_losses_across_data_parallel_group from nemo.collections.nlp.parts.utils_funcs import get_last_rank from nemo.utils import logging @@ -264,6 +265,60 @@ def generate_log_images(self, latents, batch, model): ] return log_img, log_reward, vae_decoder_output + @torch.no_grad() + def annealed_guidance(self, batch, x_T, weighing_fn=None): + ''' this function tries to perform sampling with a modified score function at each step which is an average + of the base model and the trained model ''' + if weighing_fn is None: + weighing_fn = lambda sigma1, sigma2, i, total: i*1.0/total + + with torch.cuda.amp.autocast( + enabled=self.autocast_dtype in (torch.half, torch.bfloat16), dtype=self.autocast_dtype, + ): + batch_c = self.append_sdxl_size_keys(batch) + truncation_steps = self.cfg.truncation_steps + force_uc_zero_embeddings = ['txt', 'captions'] + sampler = self.sampler + # get conditional guidance keys + cond, uc = self.model.conditioner.get_unconditional_conditioning( + batch_c, batch_uc=None, force_uc_zero_embeddings=force_uc_zero_embeddings, + ) + additional_model_inputs = {} + # get denoisers for base and trained model + denoiser_draft = lambda input, sigma, c: self.model.denoiser(self.model.model, input, sigma, c, **additional_model_inputs) + base_model = self.init_model or self.model + denoiser_base = lambda input, sigma, c: base_model.denoiser(base_model.model, input, sigma, 
c, **additional_model_inputs) + # prep initial sampler config + x = x_T.clone() + num_steps = sampler.num_steps + x, s_in, sigmas, num_sigmas, cond, uc = sampler.prepare_sampling_loop(x, cond, uc, num_steps) + # last step doesnt count since there is no additional sigma + total_steps = num_sigmas-1 + iterator = tqdm(range(num_sigmas-1), desc=f"{sampler.__class__.__name__} Sampler", total=total_steps) + base_model = self.init_model or self.model + for i in iterator: + gamma = sampler.get_gamma(sigmas, num_sigmas, i) + # with context(set_draft_grad_flag): + # just run the sampling without storing any grad + x_next_draft, eps_draft = sampler.sampler_step(s_in * sigmas[i], s_in * sigmas[i+1], denoiser_draft, x.clone(), cond, uc, gamma, return_noise=True) + # get base model + with adapter_control(base_model): + _, eps_init = sampler.sampler_step(s_in * sigmas[i], s_in * sigmas[i+1], denoiser_base, x.clone(), cond, uc, gamma, return_noise=True) + # get weighing scheme + w_draft = float(weighing_fn(sigmas[i], sigmas[i+1], i, total_steps)) + w_base = 1 - w_draft + # combine weights + eps = w_base * eps_init + w_draft * eps_draft + dt = append_dims(s_in * sigmas[i+1] - s_in * sigmas[i], x.ndim) + euler_step = sampler.euler_step(x, eps, dt) + # get next x + x = sampler.possible_correction_step(euler_step, x.clone(), eps, dt, s_in * sigmas[i+1], denoiser_draft, cond, uc) + iterator.set_description(f"iteration: {i}/{total_steps}, w_base={w_base:06f}") + # decode the latent + image = self.model.differentiable_decode_first_stage(x) + image = torch.clamp((image + 1.0)/2.0, min=0.0, max=1.0) * 255.0 + return image + @torch.no_grad() def log_visualization(self, prompts): batch_size = len(prompts) From 78f3cb992a6194788e9e195474fee34922a8aec4 Mon Sep 17 00:00:00 2001 From: Rohit Jena Date: Mon, 26 Aug 2024 15:42:56 -0700 Subject: [PATCH 2/8] updated draftp docs and annealing bash scripts Signed-off-by: Rohit Jena --- docs/user-guide/draftp.rst | 36 +++++++++++++++---- 
.../mm/stable_diffusion/launch_annealing.sh | 5 +-- .../stable_diffusion/launch_annealing_xl.sh | 10 +----- 3 files changed, 32 insertions(+), 19 deletions(-) diff --git a/docs/user-guide/draftp.rst b/docs/user-guide/draftp.rst index 3c67b8553..231aebb43 100644 --- a/docs/user-guide/draftp.rst +++ b/docs/user-guide/draftp.rst @@ -58,7 +58,7 @@ You can then run the following snipet to convert it to a ``.tar`` file: Reward Model ############ -Currently, we only have support for `Pickscore `__ reward model. Since Pickscore is a CLIP-based model, +Currently, we only have support for `Pickscore-style `__ reward models (PickScore/HPSv2). Since Pickscore is a CLIP-based model, you can use the `conversion script `__ from NeMo to convert it from huggingface to NeMo. DRaFT+ Training @@ -81,8 +81,9 @@ To launch reward model training, you must have checkpoints for `UNet `__ and `sd_lora_infer.py `__ scripts from the NeMo codebase. The generated images with the fine-tuned model should have -better prompt alignment and aesthetic quality. \ No newline at end of file +better prompt alignment and aesthetic quality. + +User controllable finetuning with Annealed Importance Guidance (AIG) +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + +AIG provides the inference-time flexibility to interpolate between the base Stable Diffusion model (with low rewards and high diversity) and DRaFT-finetuned model (with high rewards and low diversity) to obtain images with high rewards and high diversity. AIG inference is easily done by specifying comma-separated `weight_type` strategies to interpolate between the base and finetuned model. + +.. tab-set:: + .. tab-item:: Terminal + :sync: key2 + + Weight type of `base` uses the base model for AIG, `draft` uses the finetuned model (no interpolation is done in either case). + Weight type of the form `power_` interpolates using an exponential decay specified in the AIG paper. + + To run AIG inference on the terminal directly: + + .. 
code-block:: bash + + SCRIPT="launch_annealing.sh" # or "launch_annealing_xl.sh" + DIR_SAVE_CKPT_PATH=/path/to/explicit_log_dir PROMPT="An astronaut sitting on a swing" ADDITIONAL_KWARGS="+weight_type='draft,base,power_2.0'" bash $SCRIPT + diff --git a/examples/mm/stable_diffusion/launch_annealing.sh b/examples/mm/stable_diffusion/launch_annealing.sh index 7d3f14fe4..7d65d65e5 100644 --- a/examples/mm/stable_diffusion/launch_annealing.sh +++ b/examples/mm/stable_diffusion/launch_annealing.sh @@ -12,7 +12,7 @@ PEFT=${PEFT:="sdlora"} NUM_DEVICES=8 GLOBAL_BATCH_SIZE=$((MICRO_BS*NUM_DEVICES*GRAD_ACCUMULATION)) -WANDB_NAME=SD_DRaFT+${JOBNAME}_lr_${LR}_data_${DATASET}_kl_${KL_COEF}_bs_${GLOBAL_BATCH_SIZE}_infstep_${INF_STEPS}_eta_${ETA}_peft_${PEFT} +WANDB_NAME=SD_DRaFT_annealing WEBDATASET_PATH=/path/to/${DATASET} CONFIG_PATH="/opt/nemo-aligner/examples/mm/stable_diffusion/conf" @@ -20,13 +20,10 @@ CONFIG_NAME="draftp_sd" UNET_CKPT="/path/to/unet.ckpt" VAE_CKPT="/path/to/vae.ckpt" RM_CKPT="/path/to/rewardmodel.nemo" -DIR_SAVE_CKPT_PATH=/opt/nemo-aligner/sd_draft_runs/draftp_saved_ckpts_${JOBNAME} # change this as an end-user PROMPT=${PROMPT:-"Bananas growing on an apple tree"} -mkdir -p ${DIR_SAVE_CKPT_PATH} - EVAL_SCRIPT=${EVAL_SCRIPT:-"anneal_sd.py"} set -x DEVICE="0,1,2,3,4,5,6,7" diff --git a/examples/mm/stable_diffusion/launch_annealing_xl.sh b/examples/mm/stable_diffusion/launch_annealing_xl.sh index 2ac6590e3..b88bd021f 100644 --- a/examples/mm/stable_diffusion/launch_annealing_xl.sh +++ b/examples/mm/stable_diffusion/launch_annealing_xl.sh @@ -23,15 +23,10 @@ PEFT=${PEFT:="sdlora"} NUM_DEVICES=${NUM_DEVICES:=8} GLOBAL_BATCH_SIZE=$((MICRO_BS*NUM_DEVICES*GRAD_ACCUMULATION*NUMNODES)) LOG_WANDB=${LOG_WANDB:="False"} -SLEEP=${SLEEP:=0} -if [ -z "${JOBNAME}" ]; then - echo "JOBNAME not specified, exiting" - exit -fi echo "additional kwargs: ${ADDITIONAL_KWARGS}" 
-WANDB_NAME=SDXL_DRaFT+${JOBNAME}_lr_${LR}_data_${DATASET}_kl_${KL_COEF}_bs_${GLOBAL_BATCH_SIZE}_infstep_${INF_STEPS}_eta_${ETA}_peft_${PEFT} +WANDB_NAME=SDXL_Draft_annealing WEBDATASET_PATH=/path/to/dataset CONFIG_PATH="/opt/nemo-aligner/examples/mm/stable_diffusion/conf" @@ -41,14 +36,11 @@ VAE_CKPT="/path/to/vae.ckpt" RM_CKPT="/path/to/rewardmodel.nemo" PROMPT=${PROMPT:="Bananas growing on an apple tree"} -DIR_SAVE_CKPT_PATH=/opt/nemo-aligner/sdxl_draft_runs/draftp_xl_saved_ckpts_${JOBNAME} if [ ! -z "${ACT_CKPT}" ]; then ACT_CKPT="model.activation_checkpointing=$ACT_CKPT " echo $ACT_CKPT fi -mkdir -p ${DIR_SAVE_CKPT_PATH} - ## Setup multinode parameters if [ ! -z "${RDZV_ID}" ]; then DISTRIBUTED_PARAMS="--rdzv_id $RDZV_ID --rdzv_backend c10d --rdzv_endpoint $head_node_ip:30030" From 3abca685cb731588a025aea2ad2939a44543428e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 26 Aug 2024 22:44:28 +0000 Subject: [PATCH 3/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/mm/stable_diffusion/anneal_sd.py | 47 ++--- examples/mm/stable_diffusion/anneal_sdxl.py | 167 +++++++++++------- .../megatron_sd_draftp_model.py | 11 +- .../megatron_sdxl_draftp_model.py | 54 ++++-- 4 files changed, 174 insertions(+), 105 deletions(-) diff --git a/examples/mm/stable_diffusion/anneal_sd.py b/examples/mm/stable_diffusion/anneal_sd.py index f6274dcfa..76fd13fca 100644 --- a/examples/mm/stable_diffusion/anneal_sd.py +++ b/examples/mm/stable_diffusion/anneal_sd.py @@ -12,20 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import os +from copy import deepcopy +from functools import partial + +import numpy as np import torch import torch.distributed import torch.multiprocessing as mp from megatron.core import parallel_state +from megatron.core.tensor_parallel.random import get_cuda_rng_tracker, get_data_parallel_rng_tracker_name from megatron.core.utils import divide from omegaconf.omegaconf import OmegaConf, open_dict -from copy import deepcopy -import os -from functools import partial -from torch import nn -import numpy as np -from megatron.core.tensor_parallel.random import get_cuda_rng_tracker, get_data_parallel_rng_tracker_name -from PIL import Image from packaging.version import Version +from PIL import Image +from torch import nn from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronStableDiffusionTrainerBuilder from nemo.collections.nlp.parts.peft_config import PEFT_CONFIG_MAP @@ -68,7 +69,7 @@ def main(cfg) -> None: logging.info(f"\n{OmegaConf.to_yaml(cfg)}") # set cuda device for each process - local_rank = int(os.environ.get('LOCAL_RANK', 0)) + local_rank = int(os.environ.get("LOCAL_RANK", 0)) torch.cuda.set_device(local_rank) cfg.exp_manager.create_wandb_logger = False @@ -147,10 +148,10 @@ def main(cfg) -> None: if cfg.get("prompt") is not None: logging.info(f"Override val dataset with custom prompt: {cfg.prompt}") val_dataloader = [[cfg.prompt]] - + wt_types = cfg.get("weight_type", None) if wt_types is None: - wt_types = ['base', 'draft', 'linear', 'power_2', 'power_4', 'step_0.6'] + wt_types = ["base", "draft", "linear", "power_2", "power_4", "step_0.6"] else: wt_types = wt_types.split(",") if isinstance(wt_types, str) else wt_types logging.info(f"Running on types: {wt_types}") @@ -158,28 +159,28 @@ def main(cfg) -> None: # run for all weight types for wt_type in wt_types: global_idx = 0 - if wt_type is None or wt_type == 'base': + if wt_type is None or wt_type == "base": # dummy function that assigns a value of 0 all the time logging.info("using the 
base model") wt_draft = lambda sigma, sigma_next, i, total: 0 else: - if wt_type == 'linear': - wt_draft = lambda sigma, sigma_next, i, total: i*1.0/total - elif wt_type == 'draft': + if wt_type == "linear": + wt_draft = lambda sigma, sigma_next, i, total: i * 1.0 / total + elif wt_type == "draft": wt_draft = lambda sigma, sigma_next, i, total: 1 - elif wt_type.startswith('power'): # its of the form power_{power} + elif wt_type.startswith("power"): # its of the form power_{power} pow = float(wt_type.split("_")[1]) - wt_draft = lambda sigma, sigma_next, i, total: (i*1.0/total)**pow - elif wt_type.startswith("step"): # use a step function (step_{p}) + wt_draft = lambda sigma, sigma_next, i, total: (i * 1.0 / total) ** pow + elif wt_type.startswith("step"): # use a step function (step_{p}) frac = float(wt_type.split("_")[1]) - wt_draft = lambda sigma, sigma_next, i, total: float((i*1.0/total) >= frac) + wt_draft = lambda sigma, sigma_next, i, total: float((i * 1.0 / total) >= frac) else: raise ValueError(f"invalid weighing type: {wt_type}") logging.info(f"using weighing type for annealed outputs: {wt_type}.") # initialize generator - gen = torch.Generator(device='cpu') - gen.manual_seed((1243 + 1247837 * local_rank)%(int(2**32 - 1))) + gen = torch.Generator(device="cpu") + gen.manual_seed((1243 + 1247837 * local_rank) % (int(2 ** 32 - 1))) os.makedirs(f"./annealed_outputs_sd_{wt_type}/", exist_ok=True) for batch in val_dataloader: @@ -195,7 +196,9 @@ def main(cfg) -> None: generator=gen, ).to(torch.cuda.current_device()) images = ptl_model.annealed_guidance(batch, latents, weighing_fn=wt_draft) - images = images.permute(0, 2, 3, 1).detach().float().cpu().numpy().astype(np.uint8) # outputs are already scaled from [0, 255] + images = ( + images.permute(0, 2, 3, 1).detach().float().cpu().numpy().astype(np.uint8) + ) # outputs are already scaled from [0, 255] # save to pil for i in range(images.shape[0]): i = i + global_idx @@ -206,7 +209,7 @@ def main(cfg) -> None: 
fi.write(batch[i]) # increment global index global_idx += batch_size - logging.info("Saved all images.") + logging.info("Saved all images.") if __name__ == "__main__": diff --git a/examples/mm/stable_diffusion/anneal_sdxl.py b/examples/mm/stable_diffusion/anneal_sdxl.py index 3b64059a0..b49bb0497 100644 --- a/examples/mm/stable_diffusion/anneal_sdxl.py +++ b/examples/mm/stable_diffusion/anneal_sdxl.py @@ -12,19 +12,62 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os +from copy import deepcopy +from functools import partial + +import numpy as np import torch import torch.distributed import torch.multiprocessing as mp from megatron.core import parallel_state +from megatron.core.tensor_parallel.random import get_cuda_rng_tracker, get_data_parallel_rng_tracker_name from megatron.core.utils import divide from omegaconf.omegaconf import OmegaConf, open_dict -from copy import deepcopy -import os -from functools import partial -from torch import nn -import numpy as np -from megatron.core.tensor_parallel.random import get_cuda_rng_tracker, get_data_parallel_rng_tracker_name +from packaging.version import Version from PIL import Image +from torch import nn + +# checkpointing +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + CheckpointImpl, + apply_activation_checkpointing, + checkpoint_wrapper, +) +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + +from nemo.collections.multimodal.models.text_to_image.stable_diffusion.diffusion_engine import ( + DiffusionEngine, + MegatronDiffusionEngine, +) +from nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.autoencoder import ( + AutoencoderKL, + AutoencoderKLInferenceWrapper, +) +from nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.ddpm import ( + LatentDiffusion, + MegatronLatentDiffusion, +) +from nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.model import 
( + AttnBlock, + Decoder, + Encoder, + ResnetBlock, +) +from nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.openaimodel import ( + ResBlock, + SpatialTransformer, + TimestepEmbedSequential, + UNetModel, +) +from nemo.collections.multimodal.modules.stable_diffusion.encoders.modules import ( + FrozenCLIPEmbedder, + FrozenOpenCLIPEmbedder, + FrozenOpenCLIPEmbedder2, +) +from nemo.collections.nlp.modules.common.megatron.adapters.parallel_adapters import ParallelLinearAdapter +from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder +from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPFSDPStrategy # from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronStableDiffusionTrainerBuilder from nemo.collections.nlp.parts.peft_config import PEFT_CONFIG_MAP @@ -34,10 +77,9 @@ from nemo_aligner.algorithms.supervised import SupervisedTrainer from nemo_aligner.data.mm import text_webdataset from nemo_aligner.data.nlp.builders import build_dataloader -from nemo_aligner.models.mm.stable_diffusion.image_text_rms import get_reward_model +from nemo_aligner.models.mm.stable_diffusion.image_text_rms import MegatronCLIPRewardModel, get_reward_model from nemo_aligner.models.mm.stable_diffusion.megatron_sdxl_draftp_model import MegatronSDXLDRaFTPModel from nemo_aligner.utils.distributed import Timer -from packaging.version import Version from nemo_aligner.utils.train_script_utils import ( CustomLoggerWrapper, add_custom_checkpoint_callback, @@ -48,59 +90,54 @@ retrieve_custom_trainer_state_dict, temp_pop_from_config, ) -from nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.ddpm import ( - LatentDiffusion, - MegatronLatentDiffusion, -) -from nemo.collections.multimodal.models.text_to_image.stable_diffusion.diffusion_engine import MegatronDiffusionEngine, DiffusionEngine -from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder -from 
nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPFSDPStrategy -from nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.openaimodel import UNetModel, ResBlock, SpatialTransformer, TimestepEmbedSequential -from nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.autoencoder import AutoencoderKL, AutoencoderKLInferenceWrapper -from nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.model import Encoder, Decoder, ResnetBlock, AttnBlock -from nemo_aligner.models.mm.stable_diffusion.image_text_rms import MegatronCLIPRewardModel -from nemo.collections.multimodal.modules.stable_diffusion.encoders.modules import FrozenOpenCLIPEmbedder, FrozenOpenCLIPEmbedder2, FrozenCLIPEmbedder -from nemo.collections.nlp.modules.common.megatron.adapters.parallel_adapters import ParallelLinearAdapter -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - -# checkpointing -from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (checkpoint_wrapper, CheckpointImpl, apply_activation_checkpointing) mp.set_start_method("spawn", force=True) + class MegatronStableDiffusionTrainerBuilder(MegatronTrainerBuilder): """Builder for SD model Trainer with overrides.""" + def _training_strategy(self) -> NLPDDPStrategy: """ Returns a DDP or a FSDP strategy passed to Trainer.strategy. Copied from `sd_xl_train.py` """ - if self.cfg.model.get('fsdp', False): + if self.cfg.model.get("fsdp", False): logging.info("FSDP.") assert ( - not self.cfg.model.optim.get('name') == 'distributed_fused_adam' - ), 'Distributed optimizer cannot be used with FSDP.' - if self.cfg.model.get('megatron_amp_O2', False): - logging.info('Torch FSDP is not compatible with O2 precision recipe. Setting O2 `False`.') + not self.cfg.model.optim.get("name") == "distributed_fused_adam" + ), "Distributed optimizer cannot be used with FSDP." 
+ if self.cfg.model.get("megatron_amp_O2", False): + logging.info("Torch FSDP is not compatible with O2 precision recipe. Setting O2 `False`.") self.cfg.model.megatron_amp_O2 = False - + # Check if its a full-finetuning or PEFT return NLPFSDPStrategy( - limit_all_gathers=self.cfg.model.get('fsdp_limit_all_gathers', True), - sharding_strategy=self.cfg.model.get('fsdp_sharding_strategy', 'full'), - cpu_offload=self.cfg.model.get('fsdp_cpu_offload', False), # offload on is not supported - grad_reduce_dtype=self.cfg.model.get('fsdp_grad_reduce_dtype', 32), + limit_all_gathers=self.cfg.model.get("fsdp_limit_all_gathers", True), + sharding_strategy=self.cfg.model.get("fsdp_sharding_strategy", "full"), + cpu_offload=self.cfg.model.get("fsdp_cpu_offload", False), # offload on is not supported + grad_reduce_dtype=self.cfg.model.get("fsdp_grad_reduce_dtype", 32), precision=self.cfg.trainer.precision, ## nn Sequential is supposed to capture the `t_embed`, `label_emb`, `out` layers in the unet - extra_fsdp_wrap_module={UNetModel,TimestepEmbedSequential,Decoder,ResnetBlock,AttnBlock,nn.Sequential,\ - MegatronCLIPRewardModel,FrozenOpenCLIPEmbedder,FrozenOpenCLIPEmbedder2,FrozenCLIPEmbedder,\ - ParallelLinearAdapter}, + extra_fsdp_wrap_module={ + UNetModel, + TimestepEmbedSequential, + Decoder, + ResnetBlock, + AttnBlock, + nn.Sequential, + MegatronCLIPRewardModel, + FrozenOpenCLIPEmbedder, + FrozenOpenCLIPEmbedder2, + FrozenCLIPEmbedder, + ParallelLinearAdapter, + }, # extra_fsdp_wrap_module={UNetModel,TimestepEmbedSequential,Decoder,ResnetBlock,AttnBlock,SpatialTransformer,ResBlock,\ - use_orig_params=False, #self.cfg.model.inductor, - set_buffer_dtype=self.cfg.get('fsdp_set_buffer_dtype', None), + use_orig_params=False, # self.cfg.model.inductor, + set_buffer_dtype=self.cfg.get("fsdp_set_buffer_dtype", None), ) return NLPDDPStrategy( - no_ddp_communication_hook=(not self.cfg.model.get('ddp_overlap')), + no_ddp_communication_hook=(not self.cfg.model.get("ddp_overlap")), 
gradient_as_bucket_view=self.cfg.model.gradient_as_bucket_view, find_unused_parameters=False, ) @@ -113,7 +150,7 @@ def resolve_and_create_trainer(cfg, pop_trainer_key): OmegaConf.resolve(cfg) with temp_pop_from_config(cfg.trainer, pop_trainer_key): return MegatronStableDiffusionTrainerBuilder(cfg).create_trainer() - + @hydra_runner(config_path="conf", config_name="draftp_sdxl") def main(cfg) -> None: @@ -122,7 +159,7 @@ def main(cfg) -> None: logging.info(f"\n{OmegaConf.to_yaml(cfg)}") # set cuda device for each process - local_rank = int(os.environ.get('LOCAL_RANK', 0)) + local_rank = int(os.environ.get("LOCAL_RANK", 0)) torch.cuda.set_device(local_rank) # turn off wandb logging @@ -142,7 +179,7 @@ def main(cfg) -> None: logger = CustomLoggerWrapper(trainer.loggers) # Instatiating the model here ptl_model = MegatronSDXLDRaFTPModel(cfg.model, trainer).to(torch.cuda.current_device()) - init_peft(ptl_model, cfg.model) # init peft + init_peft(ptl_model, cfg.model) # init peft trainer_restore_path = trainer.ckpt_path @@ -170,15 +207,19 @@ def main(cfg) -> None: load_gbs=True, ) init_using_ptl(trainer, ptl_model, val_dataloader, validation_ds) - - if cfg.model.get('activation_checkpointing', False): + + if cfg.model.get("activation_checkpointing", False): # call activation checkpointing here # checkpoint wrapper logging.info("Applying activation checkpointing on UNet and Decoder.") non_reentrant_wrapper = partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT) + def checkpoint_check_fn(module): return isinstance(module, (Decoder, UNetModel, MegatronCLIPRewardModel)) - apply_activation_checkpointing(ptl_model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=checkpoint_check_fn) + + apply_activation_checkpointing( + ptl_model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=checkpoint_check_fn + ) optimizer, scheduler = extract_optimizer_scheduler_from_ptl_model(ptl_model) @@ -189,7 +230,7 @@ def checkpoint_check_fn(module): 
torch.distributed.barrier() ckpt_callback = add_custom_checkpoint_callback(trainer, ptl_model) - timer = Timer(cfg.exp_manager.get("max_time_per_run", "0:03:55:00")) # save a model just before 4 hours + timer = Timer(cfg.exp_manager.get("max_time_per_run", "0:03:55:00")) # save a model just before 4 hours draft_p_trainer = SupervisedTrainer( cfg=cfg.trainer.draftp_sd, @@ -213,10 +254,10 @@ def checkpoint_check_fn(module): if cfg.get("prompt") is not None: logging.info(f"Override val dataset with custom prompt: {cfg.prompt}") val_dataloader = [[cfg.prompt]] - + wt_types = cfg.get("weight_type", None) if wt_types is None: - wt_types = ['base', 'draft', 'linear', 'power_2', 'power_4', 'step_0.6'] + wt_types = ["base", "draft", "linear", "power_2", "power_4", "step_0.6"] else: wt_types = wt_types.split(",") if isinstance(wt_types, str) else wt_types logging.info(f"Running on types: {wt_types}") @@ -224,28 +265,28 @@ def checkpoint_check_fn(module): # run for all weight types for wt_type in wt_types: global_idx = 0 - if wt_type is None or wt_type == 'base': + if wt_type is None or wt_type == "base": # dummy function that assigns a value of 0 all the time logging.info("using the base model") wt_draft = lambda sigma, sigma_next, i, total: 0 else: - if wt_type == 'linear': - wt_draft = lambda sigma, sigma_next, i, total: i*1.0/total - elif wt_type == 'draft': + if wt_type == "linear": + wt_draft = lambda sigma, sigma_next, i, total: i * 1.0 / total + elif wt_type == "draft": wt_draft = lambda sigma, sigma_next, i, total: 1 - elif wt_type.startswith('power'): # its of the form power_{power} + elif wt_type.startswith("power"): # its of the form power_{power} pow = float(wt_type.split("_")[1]) - wt_draft = lambda sigma, sigma_next, i, total: (i*1.0/total)**pow - elif wt_type.startswith("step"): # use a step function (step_{p}) + wt_draft = lambda sigma, sigma_next, i, total: (i * 1.0 / total) ** pow + elif wt_type.startswith("step"): # use a step function (step_{p}) frac = 
float(wt_type.split("_")[1]) - wt_draft = lambda sigma, sigma_next, i, total: float((i*1.0/total) >= frac) + wt_draft = lambda sigma, sigma_next, i, total: float((i * 1.0 / total) >= frac) else: raise ValueError(f"invalid weighing type: {wt_type}") logging.info(f"using weighing type for annealed outputs: {wt_type}.") # initialize generator - gen = torch.Generator(device='cpu') - gen.manual_seed((1243 + 1247837 * local_rank)%(int(2**32 - 1))) + gen = torch.Generator(device="cpu") + gen.manual_seed((1243 + 1247837 * local_rank) % (int(2 ** 32 - 1))) os.makedirs(f"./annealed_outputs_sdxl_{wt_type}/", exist_ok=True) for batch in val_dataloader: @@ -261,7 +302,9 @@ def checkpoint_check_fn(module): generator=gen, ).to(torch.cuda.current_device()) images = ptl_model.annealed_guidance(batch, latents, weighing_fn=wt_draft) - images = images.permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) # outputs are already scaled from [0, 255] + images = ( + images.permute(0, 2, 3, 1).detach().cpu().numpy().astype(np.uint8) + ) # outputs are already scaled from [0, 255] # save to pil for i in range(images.shape[0]): i = i + global_idx @@ -272,7 +315,7 @@ def checkpoint_check_fn(module): fi.write(batch[i]) # increment global index global_idx += batch_size - logging.info("Saved all images.") + logging.info("Saved all images.") if __name__ == "__main__": diff --git a/nemo_aligner/models/mm/stable_diffusion/megatron_sd_draftp_model.py b/nemo_aligner/models/mm/stable_diffusion/megatron_sd_draftp_model.py index aa447586e..c18aba5fe 100644 --- a/nemo_aligner/models/mm/stable_diffusion/megatron_sd_draftp_model.py +++ b/nemo_aligner/models/mm/stable_diffusion/megatron_sd_draftp_model.py @@ -207,10 +207,10 @@ def log_visualization(self, prompts): @torch.no_grad() def annealed_guidance(self, batch, x_T, weighing_fn=None): - ''' this function tries to perform sampling with a modified score function at each step which is an average - of the base model and the trained model ''' + """ this 
function tries to perform sampling with a modified score function at each step which is an average + of the base model and the trained model """ if weighing_fn is None: - weighing_fn = lambda sigma1, sigma2, i, total: i*1.0/total + weighing_fn = lambda sigma1, sigma2, i, total: i * 1.0 / total with torch.cuda.amp.autocast( enabled=self.autocast_dtype in (torch.half, torch.bfloat16), dtype=self.autocast_dtype, @@ -273,7 +273,7 @@ def annealed_guidance(self, batch, x_T, weighing_fn=None): device_draft_p, prev_img_draft_p.clone(), ts, - None, # model output, we shouldnt need this + None, # model output, we shouldnt need this eps, False, False, @@ -286,7 +286,7 @@ def annealed_guidance(self, batch, x_T, weighing_fn=None): # stack trajectories_predx0 = ( torch.stack(last_states, dim=0).transpose(0, 1).contiguous().view(-1, *last_states[0].shape[1:]) - ) # B1CHW -> BCHW + ) # B1CHW -> BCHW vae_decoder_output = [] for i in range(0, batch_size, self.vae_batch_size): @@ -298,7 +298,6 @@ def annealed_guidance(self, batch, x_T, weighing_fn=None): return vae_decoder_output - def generate( self, batch, x_T, ): diff --git a/nemo_aligner/models/mm/stable_diffusion/megatron_sdxl_draftp_model.py b/nemo_aligner/models/mm/stable_diffusion/megatron_sdxl_draftp_model.py index 6584af318..a7f57e334 100644 --- a/nemo_aligner/models/mm/stable_diffusion/megatron_sdxl_draftp_model.py +++ b/nemo_aligner/models/mm/stable_diffusion/megatron_sdxl_draftp_model.py @@ -267,17 +267,17 @@ def generate_log_images(self, latents, batch, model): @torch.no_grad() def annealed_guidance(self, batch, x_T, weighing_fn=None): - ''' this function tries to perform sampling with a modified score function at each step which is an average - of the base model and the trained model ''' + """ this function tries to perform sampling with a modified score function at each step which is an average + of the base model and the trained model """ if weighing_fn is None: - weighing_fn = lambda sigma1, sigma2, i, total: 
i*1.0/total + weighing_fn = lambda sigma1, sigma2, i, total: i * 1.0 / total with torch.cuda.amp.autocast( enabled=self.autocast_dtype in (torch.half, torch.bfloat16), dtype=self.autocast_dtype, ): batch_c = self.append_sdxl_size_keys(batch) truncation_steps = self.cfg.truncation_steps - force_uc_zero_embeddings = ['txt', 'captions'] + force_uc_zero_embeddings = ["txt", "captions"] sampler = self.sampler # get conditional guidance keys cond, uc = self.model.conditioner.get_unconditional_conditioning( @@ -285,38 +285,62 @@ def annealed_guidance(self, batch, x_T, weighing_fn=None): ) additional_model_inputs = {} # get denoisers for base and trained model - denoiser_draft = lambda input, sigma, c: self.model.denoiser(self.model.model, input, sigma, c, **additional_model_inputs) + denoiser_draft = lambda input, sigma, c: self.model.denoiser( + self.model.model, input, sigma, c, **additional_model_inputs + ) base_model = self.init_model or self.model - denoiser_base = lambda input, sigma, c: base_model.denoiser(base_model.model, input, sigma, c, **additional_model_inputs) + denoiser_base = lambda input, sigma, c: base_model.denoiser( + base_model.model, input, sigma, c, **additional_model_inputs + ) # prep initial sampler config x = x_T.clone() num_steps = sampler.num_steps x, s_in, sigmas, num_sigmas, cond, uc = sampler.prepare_sampling_loop(x, cond, uc, num_steps) # last step doesnt count since there is no additional sigma - total_steps = num_sigmas-1 - iterator = tqdm(range(num_sigmas-1), desc=f"{sampler.__class__.__name__} Sampler", total=total_steps) + total_steps = num_sigmas - 1 + iterator = tqdm(range(num_sigmas - 1), desc=f"{sampler.__class__.__name__} Sampler", total=total_steps) base_model = self.init_model or self.model for i in iterator: gamma = sampler.get_gamma(sigmas, num_sigmas, i) # with context(set_draft_grad_flag): # just run the sampling without storing any grad - x_next_draft, eps_draft = sampler.sampler_step(s_in * sigmas[i], s_in * sigmas[i+1], 
denoiser_draft, x.clone(), cond, uc, gamma, return_noise=True) - # get base model + x_next_draft, eps_draft = sampler.sampler_step( + s_in * sigmas[i], + s_in * sigmas[i + 1], + denoiser_draft, + x.clone(), + cond, + uc, + gamma, + return_noise=True, + ) + # get base model with adapter_control(base_model): - _, eps_init = sampler.sampler_step(s_in * sigmas[i], s_in * sigmas[i+1], denoiser_base, x.clone(), cond, uc, gamma, return_noise=True) + _, eps_init = sampler.sampler_step( + s_in * sigmas[i], + s_in * sigmas[i + 1], + denoiser_base, + x.clone(), + cond, + uc, + gamma, + return_noise=True, + ) # get weighing scheme - w_draft = float(weighing_fn(sigmas[i], sigmas[i+1], i, total_steps)) + w_draft = float(weighing_fn(sigmas[i], sigmas[i + 1], i, total_steps)) w_base = 1 - w_draft # combine weights eps = w_base * eps_init + w_draft * eps_draft - dt = append_dims(s_in * sigmas[i+1] - s_in * sigmas[i], x.ndim) + dt = append_dims(s_in * sigmas[i + 1] - s_in * sigmas[i], x.ndim) euler_step = sampler.euler_step(x, eps, dt) # get next x - x = sampler.possible_correction_step(euler_step, x.clone(), eps, dt, s_in * sigmas[i+1], denoiser_draft, cond, uc) + x = sampler.possible_correction_step( + euler_step, x.clone(), eps, dt, s_in * sigmas[i + 1], denoiser_draft, cond, uc + ) iterator.set_description(f"iteration: {i}/{total_steps}, w_base={w_base:06f}") # decode the latent image = self.model.differentiable_decode_first_stage(x) - image = torch.clamp((image + 1.0)/2.0, min=0.0, max=1.0) * 255.0 + image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0) * 255.0 return image @torch.no_grad() From 1021800b25acdabceddb9dfd5573be886dae2f21 Mon Sep 17 00:00:00 2001 From: Rohit Jena Date: Thu, 5 Sep 2024 16:08:09 -0700 Subject: [PATCH 4/8] updated documentation and removed shell scripts Signed-off-by: Rohit Jena --- docs/user-guide/draftp.rst | 143 +++++++++++++++++- examples/mm/stable_diffusion/anneal_sd.py | 1 - examples/mm/stable_diffusion/anneal_sdxl.py | 4 +- 
.../mm/stable_diffusion/launch_annealing.sh | 61 -------- .../stable_diffusion/launch_annealing_xl.sh | 88 ----------- 5 files changed, 140 insertions(+), 157 deletions(-) delete mode 100644 examples/mm/stable_diffusion/launch_annealing.sh delete mode 100644 examples/mm/stable_diffusion/launch_annealing_xl.sh diff --git a/docs/user-guide/draftp.rst b/docs/user-guide/draftp.rst index 231aebb43..ce3215ff4 100644 --- a/docs/user-guide/draftp.rst +++ b/docs/user-guide/draftp.rst @@ -185,16 +185,151 @@ User controllable finetuning with Annealed Importance Guidance (AIG) AIG provides the inference-time flexibility to interpolate between the base Stable Diffusion model (with low rewards and high diversity) and DRaFT-finetuned model (with high rewards and low diversity) to obtain images with high rewards and high diversity. AIG inference is easily done by specifying comma-separated `weight_type` strategies to interpolate between the base and finetuned model. .. tab-set:: - .. tab-item:: Terminal + .. tab-item:: AIG on Stable Diffusion XL :sync: key2 Weight type of `base` uses the base model for AIG, `draft` uses the finetuned model (no interpolation is done in either case). Weight type of the form `power_` interpolates using an exponential decay specified in the AIG paper. - To run AIG inference on the terminal directly: + To run AIG inference on the terminal directly: + + .. 
code-block:: bash + + NUMNODES=1 + LR=${LR:=0.00025} + INF_STEPS=${INF_STEPS:=25} + KL_COEF=${KL_COEF:=0.1} + ETA=${ETA:=0.0} + DATASET=${DATASET:="pickapic50k.tar"} + MICRO_BS=${MICRO_BS:=1} + GRAD_ACCUMULATION=${GRAD_ACCUMULATION:=4} + PEFT=${PEFT:="sdlora"} + NUM_DEVICES=${NUM_DEVICES:=8} + GLOBAL_BATCH_SIZE=$((MICRO_BS*NUM_DEVICES*GRAD_ACCUMULATION*NUMNODES)) + LOG_WANDB=${LOG_WANDB:="False"} + + echo "additional kwargs: ${ADDITIONAL_KWARGS}" + + WANDB_NAME=SDXL_Draft_annealing + WEBDATASET_PATH=/path/to/${DATASET} + + CONFIG_PATH="/opt/nemo-aligner/examples/mm/stable_diffusion/conf" + CONFIG_NAME=${CONFIG_NAME:="draftp_sdxl"} + UNET_CKPT="/path/to/unet.ckpt" + VAE_CKPT="/path/to/vae.ckpt" + RM_CKPT="/path/to/reward_model.nemo" + PROMPT=${PROMPT:="Bananas growing on an apple tree"} + DIR_SAVE_CKPT_PATH=/path/to/explicit_log_dir + + if [ ! -z "${ACT_CKPT}" ]; then + ACT_CKPT="model.activation_checkpointing=$ACT_CKPT " + echo $ACT_CKPT + fi + + EVAL_SCRIPT=${EVAL_SCRIPT:-"anneal_sdxl.py"} + export DEVICE="0,1,2,3,4,5,6,7" && echo "Running DRaFT+ on ${DEVICE}" && export HYDRA_FULL_ERROR=1 + set -x + CUDA_VISIBLE_DEVICES="${DEVICE}" torchrun --nproc_per_node=$NUM_DEVICES /opt/nemo-aligner/examples/mm/stable_diffusion/${EVAL_SCRIPT} \ + --config-path=${CONFIG_PATH} \ + --config-name=${CONFIG_NAME} \ + model.optim.lr=${LR} \ + model.optim.weight_decay=0.0005 \ + model.optim.sched.warmup_steps=0 \ + model.sampling.base.steps=${INF_STEPS} \ + model.kl_coeff=${KL_COEF} \ + model.truncation_steps=1 \ + trainer.draftp_sd.max_epochs=5 \ + trainer.draftp_sd.max_steps=10000 \ + trainer.draftp_sd.save_interval=200 \ + trainer.draftp_sd.val_check_interval=20 \ + trainer.draftp_sd.gradient_clip_val=10.0 \ + model.micro_batch_size=${MICRO_BS} \ + model.global_batch_size=${GLOBAL_BATCH_SIZE} \ + model.peft.peft_scheme=${PEFT} \ + model.data.webdataset.local_root_path=$WEBDATASET_PATH \ + rm.model.restore_from_path=${RM_CKPT} \ + trainer.devices=${NUM_DEVICES} \ + 
trainer.num_nodes=${NUMNODES} \ + rm.trainer.devices=${NUM_DEVICES} \ + rm.trainer.num_nodes=${NUMNODES} \ + +prompt="${PROMPT}" \ + exp_manager.create_wandb_logger=${LOG_WANDB} \ + model.first_stage_config.from_pretrained=${VAE_CKPT} \ + model.first_stage_config.from_NeMo=True \ + model.unet_config.from_pretrained=${UNET_CKPT} \ + model.unet_config.from_NeMo=True \ + $ACT_CKPT \ + exp_manager.wandb_logger_kwargs.name=${WANDB_NAME} \ + exp_manager.resume_if_exists=True \ + exp_manager.explicit_log_dir=${DIR_SAVE_CKPT_PATH} \ + exp_manager.wandb_logger_kwargs.project=${PROJECT} +weight_type='draft,base,power_2.0' + + .. tab-item:: AIG on Stable Diffusion v1.1 - v1.5 + :sync: key + + Weight type of `base` uses the base model for AIG, `draft` uses the finetuned model (no interpolation is done in either case). + Weight type of the form `power_` interpolates using an exponential decay specified in the AIG paper. + + To run AIG inference on the terminal directly: .. code-block:: bash - SCRIPT="launch_annealing.sh" # or "launch_annealing_xl.sh" - DIR_SAVE_CKPT_PATH=/path/to/explicit_log_dir PROMPT="An astronaut sitting on a swing" ADDITIONAL_KWARGS="+weight_type='draft,base,power_2.0'" bash $SCRIPT + LR=${LR:=0.00025} + INF_STEPS=${INF_STEPS:=25} + KL_COEF=${KL_COEF:=0.1} + ETA=${ETA:=0.0} + DATASET=${DATASET:="pickapic50k.tar"} + MICRO_BS=${MICRO_BS:=2} + GRAD_ACCUMULATION=${GRAD_ACCUMULATION:=4} + PEFT=${PEFT:="sdlora"} + NUM_DEVICES=8 + GLOBAL_BATCH_SIZE=$((MICRO_BS*NUM_DEVICES*GRAD_ACCUMULATION)) + + WANDB_NAME=SD_DRaFT_annealing + WEBDATASET_PATH=/path/to/${DATASET} + + CONFIG_PATH="/opt/nemo-aligner/examples/mm/stable_diffusion/conf" + CONFIG_NAME="draftp_sd" + UNET_CKPT="/path/to/unet.ckpt" + VAE_CKPT="/path/to/vae.ckpt" + RM_CKPT="/path/to/rewardmodel.nemo" + + # change this as an end-user + PROMPT=${PROMPT:-"Bananas growing on an apple tree"} + + EVAL_SCRIPT=${EVAL_SCRIPT:-"anneal_sd.py"} + set -x + DEVICE="0,1,2,3,4,5,6,7" + echo "Running DRaFT on ${DEVICE}" + 
export HYDRA_FULL_ERROR=1 \ + && MASTER_PORT=15003 CUDA_VISIBLE_DEVICES="${DEVICE}" torchrun --nproc_per_node=${NUM_DEVICES} /opt/nemo-aligner/examples/mm/stable_diffusion/${EVAL_SCRIPT} \ + --config-path=${CONFIG_PATH} \ + --config-name=${CONFIG_NAME} \ + model.optim.lr=${LR} \ + model.optim.weight_decay=0.005 \ + model.optim.sched.warmup_steps=0 \ + model.infer.inference_steps=${INF_STEPS} \ + model.infer.eta=0.0 \ + model.kl_coeff=${KL_COEF} \ + model.truncation_steps=1 \ + trainer.draftp_sd.max_epochs=1 \ + trainer.draftp_sd.max_steps=4000 \ + trainer.draftp_sd.save_interval=100 \ + model.unet_config.from_pretrained=${UNET_CKPT} \ + model.first_stage_config.from_pretrained=${VAE_CKPT} \ + model.micro_batch_size=${MICRO_BS} \ + model.global_batch_size=${GLOBAL_BATCH_SIZE} \ + model.peft.peft_scheme=${PEFT} \ + model.data.webdataset.local_root_path=$WEBDATASET_PATH \ + rm.model.restore_from_path=${RM_CKPT} \ + +prompt="${PROMPT}" \ + trainer.draftp_sd.val_check_interval=20 \ + trainer.draftp_sd.gradient_clip_val=10.0 \ + trainer.devices=${NUM_DEVICES} \ + rm.trainer.devices=${NUM_DEVICES} \ + exp_manager.create_wandb_logger=True \ + exp_manager.wandb_logger_kwargs.name=${WANDB_NAME} \ + exp_manager.resume_if_exists=True \ + exp_manager.explicit_log_dir=${DIR_SAVE_CKPT_PATH} \ + exp_manager.wandb_logger_kwargs.project=${PROJECT} +weight_type='draft,base,power_2.0' diff --git a/examples/mm/stable_diffusion/anneal_sd.py b/examples/mm/stable_diffusion/anneal_sd.py index 76fd13fca..86c9cb387 100644 --- a/examples/mm/stable_diffusion/anneal_sd.py +++ b/examples/mm/stable_diffusion/anneal_sd.py @@ -102,7 +102,6 @@ def main(cfg) -> None: train_ds, validation_ds = text_webdataset.build_train_valid_datasets( cfg.model.data, consumed_samples=consumed_samples ) - # train_ds = [d["captions"] for d in list(train_ds)] validation_ds = [d["captions"] for d in list(validation_ds)] val_dataloader = build_dataloader( diff --git a/examples/mm/stable_diffusion/anneal_sdxl.py 
b/examples/mm/stable_diffusion/anneal_sdxl.py index b49bb0497..515479ace 100644 --- a/examples/mm/stable_diffusion/anneal_sdxl.py +++ b/examples/mm/stable_diffusion/anneal_sdxl.py @@ -93,7 +93,6 @@ mp.set_start_method("spawn", force=True) - class MegatronStableDiffusionTrainerBuilder(MegatronTrainerBuilder): """Builder for SD model Trainer with overrides.""" @@ -131,8 +130,7 @@ def _training_strategy(self) -> NLPDDPStrategy: FrozenCLIPEmbedder, ParallelLinearAdapter, }, - # extra_fsdp_wrap_module={UNetModel,TimestepEmbedSequential,Decoder,ResnetBlock,AttnBlock,SpatialTransformer,ResBlock,\ - use_orig_params=False, # self.cfg.model.inductor, + use_orig_params=False, set_buffer_dtype=self.cfg.get("fsdp_set_buffer_dtype", None), ) diff --git a/examples/mm/stable_diffusion/launch_annealing.sh b/examples/mm/stable_diffusion/launch_annealing.sh deleted file mode 100644 index 7d65d65e5..000000000 --- a/examples/mm/stable_diffusion/launch_annealing.sh +++ /dev/null @@ -1,61 +0,0 @@ -#!/bin/bash -export PYTHONPATH=/opt/NeMo:/opt/nemo-aligner:$PYTHONPATH - -LR=${LR:=0.00025} -INF_STEPS=${INF_STEPS:=25} -KL_COEF=${KL_COEF:=0.1} -ETA=${ETA:=0.0} -DATASET=${DATASET:="pickapic50k.tar"} -MICRO_BS=${MICRO_BS:=2} -GRAD_ACCUMULATION=${GRAD_ACCUMULATION:=4} -PEFT=${PEFT:="sdlora"} -NUM_DEVICES=8 -GLOBAL_BATCH_SIZE=$((MICRO_BS*NUM_DEVICES*GRAD_ACCUMULATION)) - -WANDB_NAME=SD_DRaFT_annealing -WEBDATASET_PATH=/path/to/${DATASET} - -CONFIG_PATH="/opt/nemo-aligner/examples/mm/stable_diffusion/conf" -CONFIG_NAME="draftp_sd" -UNET_CKPT="/path/to/unet.ckpt" -VAE_CKPT="/path/to/vae.ckpt" -RM_CKPT="/path/to/rewardmodel.nemo" - -# change this as an end-user -PROMPT=${PROMPT:-"Bananas growing on an apple tree"} - -EVAL_SCRIPT=${EVAL_SCRIPT:-"anneal_sd.py"} -set -x -DEVICE="0,1,2,3,4,5,6,7" -echo "Running DRaFT on ${DEVICE}" -export HYDRA_FULL_ERROR=1 \ -&& MASTER_PORT=15003 CUDA_VISIBLE_DEVICES="${DEVICE}" torchrun --nproc_per_node=8 /opt/nemo-aligner/examples/mm/stable_diffusion/${EVAL_SCRIPT} 
\ - --config-path=${CONFIG_PATH} \ - --config-name=${CONFIG_NAME} \ - model.optim.lr=${LR} \ - model.optim.weight_decay=0.005 \ - model.optim.sched.warmup_steps=0 \ - model.infer.inference_steps=${INF_STEPS} \ - model.infer.eta=0.0 \ - model.kl_coeff=${KL_COEF} \ - model.truncation_steps=1 \ - trainer.draftp_sd.max_epochs=1 \ - trainer.draftp_sd.max_steps=4000 \ - trainer.draftp_sd.save_interval=100 \ - model.unet_config.from_pretrained=${UNET_CKPT} \ - model.first_stage_config.from_pretrained=${VAE_CKPT} \ - model.micro_batch_size=${MICRO_BS} \ - model.global_batch_size=${GLOBAL_BATCH_SIZE} \ - model.peft.peft_scheme=${PEFT} \ - model.data.webdataset.local_root_path=$WEBDATASET_PATH \ - rm.model.restore_from_path=${RM_CKPT} \ - +prompt="${PROMPT}" \ - trainer.draftp_sd.val_check_interval=20 \ - trainer.draftp_sd.gradient_clip_val=10.0 \ - trainer.devices=${NUM_DEVICES} \ - rm.trainer.devices=${NUM_DEVICES} \ - exp_manager.create_wandb_logger=True \ - exp_manager.wandb_logger_kwargs.name=${WANDB_NAME} \ - exp_manager.resume_if_exists=True \ - exp_manager.explicit_log_dir=${DIR_SAVE_CKPT_PATH} \ - exp_manager.wandb_logger_kwargs.project=${PROJECT} $ADDITIONAL_KWARGS diff --git a/examples/mm/stable_diffusion/launch_annealing_xl.sh b/examples/mm/stable_diffusion/launch_annealing_xl.sh deleted file mode 100644 index b88bd021f..000000000 --- a/examples/mm/stable_diffusion/launch_annealing_xl.sh +++ /dev/null @@ -1,88 +0,0 @@ -#!/bin/bash - -export PYTHONPATH=/opt/NeMo:/opt/nemo-aligner:$PYTHONPATH - -# setup multinodes -if [ ! 
-z "$NNODES" ]; then - NUMNODES=$NNODES; - NNODES="--nnodes $NNODES" -else - NUMNODES=1; - NNODES="" -fi - -echo "Setting nodes to $NNODES" -LR=${LR:=0.00025} -INF_STEPS=${INF_STEPS:=25} -KL_COEF=${KL_COEF:=0.1} -ETA=${ETA:=0.0} -DATASET=${DATASET:="pickapic50k.tar"} -MICRO_BS=${MICRO_BS:=1} -GRAD_ACCUMULATION=${GRAD_ACCUMULATION:=4} -PEFT=${PEFT:="sdlora"} -NUM_DEVICES=${NUM_DEVICES:=8} -GLOBAL_BATCH_SIZE=$((MICRO_BS*NUM_DEVICES*GRAD_ACCUMULATION*NUMNODES)) -LOG_WANDB=${LOG_WANDB:="False"} - -echo "additional kwargs: ${ADDITIONAL_KWARGS}" - -WANDB_NAME=SDXL_Draft_annealing -WEBDATASET_PATH=/path/to/dataset - -CONFIG_PATH="/opt/nemo-aligner/examples/mm/stable_diffusion/conf" -CONFIG_NAME=${CONFIG_NAME:="draftp_sdxl"} -UNET_CKPT="/path/to/unet.ckpt" -VAE_CKPT="/path/to/vae.ckpt" -RM_CKPT="/path/to/rewardmodel.nemo" -PROMPT=${PROMPT:="Bananas growing on an apple tree"} - -if [ ! -z "${ACT_CKPT}" ]; then - ACT_CKPT="model.activation_checkpointing=$ACT_CKPT " - echo $ACT_CKPT -fi - -## Setup multinode parameters -if [ ! 
-z "${RDZV_ID}" ]; then - DISTRIBUTED_PARAMS="--rdzv_id $RDZV_ID --rdzv_backend c10d --rdzv_endpoint $head_node_ip:30030" -else - DISTRIBUTED_PARAMS="--master_port=30030" -fi -echo "Setting distributed params to $DISTRIBUTED_PARAMS" - -EVAL_SCRIPT=${EVAL_SCRIPT:-"anneal_sdxl.py"} -export DEVICE="0,1,2,3,4,5,6,7" && echo "Running DRaFT+ on ${DEVICE}" && export HYDRA_FULL_ERROR=1 -set -x -CUDA_VISIBLE_DEVICES="${DEVICE}" torchrun --nproc_per_node=$NUM_DEVICES $NNODES $DISTRIBUTED_PARAMS /opt/nemo-aligner/examples/mm/stable_diffusion/${EVAL_SCRIPT} \ - --config-path=${CONFIG_PATH} \ - --config-name=${CONFIG_NAME} \ - model.optim.lr=${LR} \ - model.optim.weight_decay=0.0005 \ - model.optim.sched.warmup_steps=0 \ - model.sampling.base.steps=${INF_STEPS} \ - model.kl_coeff=${KL_COEF} \ - model.truncation_steps=1 \ - trainer.draftp_sd.max_epochs=5 \ - trainer.draftp_sd.max_steps=10000 \ - trainer.draftp_sd.save_interval=200 \ - trainer.draftp_sd.val_check_interval=20 \ - trainer.draftp_sd.gradient_clip_val=10.0 \ - model.micro_batch_size=${MICRO_BS} \ - model.global_batch_size=${GLOBAL_BATCH_SIZE} \ - model.peft.peft_scheme=${PEFT} \ - model.data.webdataset.local_root_path=$WEBDATASET_PATH \ - rm.model.restore_from_path=${RM_CKPT} \ - trainer.devices=${NUM_DEVICES} \ - trainer.num_nodes=${NUMNODES} \ - rm.trainer.devices=${NUM_DEVICES} \ - rm.trainer.num_nodes=${NUMNODES} \ - +prompt="${PROMPT}" \ - exp_manager.create_wandb_logger=${LOG_WANDB} \ - model.first_stage_config.from_pretrained=${VAE_CKPT} \ - model.first_stage_config.from_NeMo=True \ - model.unet_config.from_pretrained=${UNET_CKPT} \ - model.unet_config.from_NeMo=True \ - $ACT_CKPT \ - exp_manager.wandb_logger_kwargs.name=${WANDB_NAME} \ - exp_manager.resume_if_exists=True \ - exp_manager.explicit_log_dir=${DIR_SAVE_CKPT_PATH} \ - exp_manager.wandb_logger_kwargs.project=${PROJECT} ${ADDITIONAL_KWARGS} From 75a46dd73e38d04831b21b4aa41a1f525b79633f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" 
<66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Sep 2024 23:08:51 +0000 Subject: [PATCH 5/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/mm/stable_diffusion/anneal_sdxl.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/mm/stable_diffusion/anneal_sdxl.py b/examples/mm/stable_diffusion/anneal_sdxl.py index 515479ace..64e032b12 100644 --- a/examples/mm/stable_diffusion/anneal_sdxl.py +++ b/examples/mm/stable_diffusion/anneal_sdxl.py @@ -93,6 +93,7 @@ mp.set_start_method("spawn", force=True) + class MegatronStableDiffusionTrainerBuilder(MegatronTrainerBuilder): """Builder for SD model Trainer with overrides.""" @@ -130,7 +131,7 @@ def _training_strategy(self) -> NLPDDPStrategy: FrozenCLIPEmbedder, ParallelLinearAdapter, }, - use_orig_params=False, + use_orig_params=False, set_buffer_dtype=self.cfg.get("fsdp_set_buffer_dtype", None), ) From de8862bc494cc611496a91e70013017efba56cc5 Mon Sep 17 00:00:00 2001 From: Rohit Jena Date: Fri, 6 Sep 2024 11:28:00 -0700 Subject: [PATCH 6/8] removing AIG for SD for now and related documentation - Ali will take care of the merging of the two scripts Also changed the weighting function, timer defaults, and relative paths for saving images Signed-off-by: Rohit Jena --- docs/user-guide/draftp.rst | 67 ------ examples/mm/stable_diffusion/anneal_sd.py | 215 -------------------- examples/mm/stable_diffusion/anneal_sdxl.py | 36 ++-- 3 files changed, 18 insertions(+), 300 deletions(-) delete mode 100644 examples/mm/stable_diffusion/anneal_sd.py diff --git a/docs/user-guide/draftp.rst b/docs/user-guide/draftp.rst index ce3215ff4..b6fdb8d3c 100644 --- a/docs/user-guide/draftp.rst +++ b/docs/user-guide/draftp.rst @@ -264,72 +264,5 @@ AIG provides the inference-time flexibility to interpolate between the base Stab exp_manager.explicit_log_dir=${DIR_SAVE_CKPT_PATH} \ exp_manager.wandb_logger_kwargs.project=${PROJECT} 
+weight_type='draft,base,power_2.0' - .. tab-item:: AIG on Stable Diffusion v1.1 - v1.5 - :sync: key - Weight type of `base` uses the base model for AIG, `draft` uses the finetuned model (no interpolation is done in either case). - Weight type of the form `power_` interpolates using an exponential decay specified in the AIG paper. - - To run AIG inference on the terminal directly: - - .. code-block:: bash - - LR=${LR:=0.00025} - INF_STEPS=${INF_STEPS:=25} - KL_COEF=${KL_COEF:=0.1} - ETA=${ETA:=0.0} - DATASET=${DATASET:="pickapic50k.tar"} - MICRO_BS=${MICRO_BS:=2} - GRAD_ACCUMULATION=${GRAD_ACCUMULATION:=4} - PEFT=${PEFT:="sdlora"} - NUM_DEVICES=8 - GLOBAL_BATCH_SIZE=$((MICRO_BS*NUM_DEVICES*GRAD_ACCUMULATION)) - - WANDB_NAME=SD_DRaFT_annealing - WEBDATASET_PATH=/path/to/${DATASET} - - CONFIG_PATH="/opt/nemo-aligner/examples/mm/stable_diffusion/conf" - CONFIG_NAME="draftp_sd" - UNET_CKPT="/path/to/unet.ckpt" - VAE_CKPT="/path/to/vae.ckpt" - RM_CKPT="/path/to/rewardmodel.nemo" - - # change this as an end-user - PROMPT=${PROMPT:-"Bananas growing on an apple tree"} - - EVAL_SCRIPT=${EVAL_SCRIPT:-"anneal_sd.py"} - set -x - DEVICE="0,1,2,3,4,5,6,7" - echo "Running DRaFT on ${DEVICE}" - export HYDRA_FULL_ERROR=1 \ - && MASTER_PORT=15003 CUDA_VISIBLE_DEVICES="${DEVICE}" torchrun --nproc_per_node=${NUM_DEVICES} /opt/nemo-aligner/examples/mm/stable_diffusion/${EVAL_SCRIPT} \ - --config-path=${CONFIG_PATH} \ - --config-name=${CONFIG_NAME} \ - model.optim.lr=${LR} \ - model.optim.weight_decay=0.005 \ - model.optim.sched.warmup_steps=0 \ - model.infer.inference_steps=${INF_STEPS} \ - model.infer.eta=0.0 \ - model.kl_coeff=${KL_COEF} \ - model.truncation_steps=1 \ - trainer.draftp_sd.max_epochs=1 \ - trainer.draftp_sd.max_steps=4000 \ - trainer.draftp_sd.save_interval=100 \ - model.unet_config.from_pretrained=${UNET_CKPT} \ - model.first_stage_config.from_pretrained=${VAE_CKPT} \ - model.micro_batch_size=${MICRO_BS} \ - model.global_batch_size=${GLOBAL_BATCH_SIZE} \ - 
model.peft.peft_scheme=${PEFT} \ - model.data.webdataset.local_root_path=$WEBDATASET_PATH \ - rm.model.restore_from_path=${RM_CKPT} \ - +prompt="${PROMPT}" \ - trainer.draftp_sd.val_check_interval=20 \ - trainer.draftp_sd.gradient_clip_val=10.0 \ - trainer.devices=${NUM_DEVICES} \ - rm.trainer.devices=${NUM_DEVICES} \ - exp_manager.create_wandb_logger=True \ - exp_manager.wandb_logger_kwargs.name=${WANDB_NAME} \ - exp_manager.resume_if_exists=True \ - exp_manager.explicit_log_dir=${DIR_SAVE_CKPT_PATH} \ - exp_manager.wandb_logger_kwargs.project=${PROJECT} +weight_type='draft,base,power_2.0' diff --git a/examples/mm/stable_diffusion/anneal_sd.py b/examples/mm/stable_diffusion/anneal_sd.py deleted file mode 100644 index 86c9cb387..000000000 --- a/examples/mm/stable_diffusion/anneal_sd.py +++ /dev/null @@ -1,215 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import os -from copy import deepcopy -from functools import partial - -import numpy as np -import torch -import torch.distributed -import torch.multiprocessing as mp -from megatron.core import parallel_state -from megatron.core.tensor_parallel.random import get_cuda_rng_tracker, get_data_parallel_rng_tracker_name -from megatron.core.utils import divide -from omegaconf.omegaconf import OmegaConf, open_dict -from packaging.version import Version -from PIL import Image -from torch import nn - -from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronStableDiffusionTrainerBuilder -from nemo.collections.nlp.parts.peft_config import PEFT_CONFIG_MAP -from nemo.core.config import hydra_runner -from nemo.utils import logging -from nemo.utils.exp_manager import exp_manager -from nemo_aligner.algorithms.supervised import SupervisedTrainer -from nemo_aligner.data.mm import text_webdataset -from nemo_aligner.data.nlp.builders import build_dataloader -from nemo_aligner.models.mm.stable_diffusion.image_text_rms import get_reward_model -from nemo_aligner.models.mm.stable_diffusion.megatron_sd_draftp_model import MegatronSDDRaFTPModel -from nemo_aligner.utils.distributed import Timer -from nemo_aligner.utils.train_script_utils import ( - CustomLoggerWrapper, - add_custom_checkpoint_callback, - extract_optimizer_scheduler_from_ptl_model, - init_distributed, - init_peft, - init_using_ptl, - retrieve_custom_trainer_state_dict, - temp_pop_from_config, -) - -mp.set_start_method("spawn", force=True) - - -def resolve_and_create_trainer(cfg, pop_trainer_key): - """resolve the cfg, remove the key before constructing the PTL trainer - and then restore it after - """ - OmegaConf.resolve(cfg) - with temp_pop_from_config(cfg.trainer, pop_trainer_key): - return MegatronStableDiffusionTrainerBuilder(cfg).create_trainer() - - -@hydra_runner(config_path="conf", config_name="draftp_sd") -def main(cfg) -> None: - - logging.info("\n\n************** Experiment configuration 
***********") - logging.info(f"\n{OmegaConf.to_yaml(cfg)}") - - # set cuda device for each process - local_rank = int(os.environ.get("LOCAL_RANK", 0)) - torch.cuda.set_device(local_rank) - - cfg.exp_manager.create_wandb_logger = False - - if Version(torch.__version__) >= Version("1.12"): - torch.backends.cuda.matmul.allow_tf32 = True - cfg.model.data.train.dataset_path = [cfg.model.data.webdataset.local_root_path for _ in range(cfg.trainer.devices)] - cfg.model.data.validation.dataset_path = [ - cfg.model.data.webdataset.local_root_path for _ in range(cfg.trainer.devices) - ] - - trainer = resolve_and_create_trainer(cfg, "draftp_sd") - exp_manager(trainer, cfg.exp_manager) - logger = CustomLoggerWrapper(trainer.loggers) - # Instatiating the model here - ptl_model = MegatronSDDRaFTPModel(cfg.model, trainer).to(torch.cuda.current_device()) - init_peft(ptl_model, cfg.model) - - trainer_restore_path = trainer.ckpt_path - - if trainer_restore_path is not None: - custom_trainer_state_dict = retrieve_custom_trainer_state_dict(trainer) - consumed_samples = custom_trainer_state_dict["consumed_samples"] - else: - custom_trainer_state_dict = None - consumed_samples = 0 - - init_distributed(trainer, ptl_model, cfg.model.get("transformer_engine", False)) - - train_ds, validation_ds = text_webdataset.build_train_valid_datasets( - cfg.model.data, consumed_samples=consumed_samples - ) - validation_ds = [d["captions"] for d in list(validation_ds)] - - val_dataloader = build_dataloader( - cfg, - dataset=validation_ds, - consumed_samples=consumed_samples, - mbs=cfg.model.micro_batch_size, - gbs=cfg.model.global_batch_size, - load_gbs=True, - ) - - init_using_ptl(trainer, ptl_model, val_dataloader, validation_ds) - - optimizer, scheduler = extract_optimizer_scheduler_from_ptl_model(ptl_model) - - ckpt_callback = add_custom_checkpoint_callback(trainer, ptl_model) - - logger.log_hyperparams(OmegaConf.to_container(cfg)) - - reward_model = get_reward_model(cfg.rm, 
mbs=cfg.model.micro_batch_size, gbs=cfg.model.global_batch_size) - ptl_model.reward_model = reward_model - - ckpt_callback = add_custom_checkpoint_callback(trainer, ptl_model) - timer = Timer(cfg.exp_manager.get("max_time_per_run", "0:12:00:00")) - - draft_p_trainer = SupervisedTrainer( - cfg=cfg.trainer.draftp_sd, - model=ptl_model, - optimizer=optimizer, - scheduler=scheduler, - train_dataloader=val_dataloader, - val_dataloader=val_dataloader, - test_dataloader=[], - logger=logger, - ckpt_callback=ckpt_callback, - run_timer=timer, - ) - - if custom_trainer_state_dict is not None: - draft_p_trainer.load_state_dict(custom_trainer_state_dict) - - # Run annealed guidance - if cfg.get("prompt") is not None: - logging.info(f"Override val dataset with custom prompt: {cfg.prompt}") - val_dataloader = [[cfg.prompt]] - - wt_types = cfg.get("weight_type", None) - if wt_types is None: - wt_types = ["base", "draft", "linear", "power_2", "power_4", "step_0.6"] - else: - wt_types = wt_types.split(",") if isinstance(wt_types, str) else wt_types - logging.info(f"Running on types: {wt_types}") - - # run for all weight types - for wt_type in wt_types: - global_idx = 0 - if wt_type is None or wt_type == "base": - # dummy function that assigns a value of 0 all the time - logging.info("using the base model") - wt_draft = lambda sigma, sigma_next, i, total: 0 - else: - if wt_type == "linear": - wt_draft = lambda sigma, sigma_next, i, total: i * 1.0 / total - elif wt_type == "draft": - wt_draft = lambda sigma, sigma_next, i, total: 1 - elif wt_type.startswith("power"): # its of the form power_{power} - pow = float(wt_type.split("_")[1]) - wt_draft = lambda sigma, sigma_next, i, total: (i * 1.0 / total) ** pow - elif wt_type.startswith("step"): # use a step function (step_{p}) - frac = float(wt_type.split("_")[1]) - wt_draft = lambda sigma, sigma_next, i, total: float((i * 1.0 / total) >= frac) - else: - raise ValueError(f"invalid weighing type: {wt_type}") - logging.info(f"using 
weighing type for annealed outputs: {wt_type}.") - - # initialize generator - gen = torch.Generator(device="cpu") - gen.manual_seed((1243 + 1247837 * local_rank) % (int(2 ** 32 - 1))) - os.makedirs(f"./annealed_outputs_sd_{wt_type}/", exist_ok=True) - - for batch in val_dataloader: - batch_size = len(batch) - with get_cuda_rng_tracker().fork(get_data_parallel_rng_tracker_name()): - latents = torch.randn( - [ - batch_size, - ptl_model.in_channels, - ptl_model.height // ptl_model.downsampling_factor, - ptl_model.width // ptl_model.downsampling_factor, - ], - generator=gen, - ).to(torch.cuda.current_device()) - images = ptl_model.annealed_guidance(batch, latents, weighing_fn=wt_draft) - images = ( - images.permute(0, 2, 3, 1).detach().float().cpu().numpy().astype(np.uint8) - ) # outputs are already scaled from [0, 255] - # save to pil - for i in range(images.shape[0]): - i = i + global_idx - img_path = f"annealed_outputs_sd_{wt_type}/img_{i:05d}_{local_rank:02d}.png" - prompt_path = f"annealed_outputs_sd_{wt_type}/prompt_{i:05d}_{local_rank:02d}.txt" - Image.fromarray(images[i]).save(img_path) - with open(prompt_path, "w") as fi: - fi.write(batch[i]) - # increment global index - global_idx += batch_size - logging.info("Saved all images.") - - -if __name__ == "__main__": - main() diff --git a/examples/mm/stable_diffusion/anneal_sdxl.py b/examples/mm/stable_diffusion/anneal_sdxl.py index 515479ace..678b799c2 100644 --- a/examples/mm/stable_diffusion/anneal_sdxl.py +++ b/examples/mm/stable_diffusion/anneal_sdxl.py @@ -228,7 +228,7 @@ def checkpoint_check_fn(module): torch.distributed.barrier() ckpt_callback = add_custom_checkpoint_callback(trainer, ptl_model) - timer = Timer(cfg.exp_manager.get("max_time_per_run", "0:03:55:00")) # save a model just before 4 hours + timer = Timer(cfg.exp_manager.get("max_time_per_run", None) if cfg.exp_manager else None) draft_p_trainer = SupervisedTrainer( cfg=cfg.trainer.draftp_sd, @@ -263,29 +263,29 @@ def checkpoint_check_fn(module): 
# run for all weight types for wt_type in wt_types: global_idx = 0 - if wt_type is None or wt_type == "base": + if wt_type == "base": # dummy function that assigns a value of 0 all the time logging.info("using the base model") wt_draft = lambda sigma, sigma_next, i, total: 0 + elif wt_type == "linear": + wt_draft = lambda sigma, sigma_next, i, total: i * 1.0 / total + elif wt_type == "draft": + wt_draft = lambda sigma, sigma_next, i, total: 1 + elif wt_type.startswith("power"): # its of the form power_{power} + pow = float(wt_type.split("_")[1]) + wt_draft = lambda sigma, sigma_next, i, total: (i * 1.0 / total) ** pow + elif wt_type.startswith("step"): # use a step function (step_{p}) + frac = float(wt_type.split("_")[1]) + wt_draft = lambda sigma, sigma_next, i, total: float((i * 1.0 / total) >= frac) else: - if wt_type == "linear": - wt_draft = lambda sigma, sigma_next, i, total: i * 1.0 / total - elif wt_type == "draft": - wt_draft = lambda sigma, sigma_next, i, total: 1 - elif wt_type.startswith("power"): # its of the form power_{power} - pow = float(wt_type.split("_")[1]) - wt_draft = lambda sigma, sigma_next, i, total: (i * 1.0 / total) ** pow - elif wt_type.startswith("step"): # use a step function (step_{p}) - frac = float(wt_type.split("_")[1]) - wt_draft = lambda sigma, sigma_next, i, total: float((i * 1.0 / total) >= frac) - else: - raise ValueError(f"invalid weighing type: {wt_type}") - logging.info(f"using weighing type for annealed outputs: {wt_type}.") + raise ValueError(f"invalid weighing type: {wt_type}") + logging.info(f"using weighing type for annealed outputs: {wt_type}.") # initialize generator + exp_dir = cfg.exp_manager.explicit_log_dir gen = torch.Generator(device="cpu") gen.manual_seed((1243 + 1247837 * local_rank) % (int(2 ** 32 - 1))) - os.makedirs(f"./annealed_outputs_sdxl_{wt_type}/", exist_ok=True) + os.makedirs(os.path.join(exp_dir, f"annealed_outputs_sdxl_{wt_type}/"), exist_ok=True) for batch in val_dataloader: batch_size = 
len(batch) @@ -306,8 +306,8 @@ def checkpoint_check_fn(module): # save to pil for i in range(images.shape[0]): i = i + global_idx - img_path = f"annealed_outputs_sdxl_{wt_type}/img_{i:05d}_{local_rank:02d}.png" - prompt_path = f"annealed_outputs_sdxl_{wt_type}/prompt_{i:05d}_{local_rank:02d}.txt" + img_path = os.path.join(exp_dir, f"annealed_outputs_sdxl_{wt_type}/img_{i:05d}_{local_rank:02d}.png") + prompt_path = os.path.join(exp_dir, f"annealed_outputs_sdxl_{wt_type}/prompt_{i:05d}_{local_rank:02d}.txt") Image.fromarray(images[i]).save(img_path) with open(prompt_path, "w") as fi: fi.write(batch[i]) From 9c18398bb74024581dc93be3e76d69747aeaf123 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Sep 2024 18:29:39 +0000 Subject: [PATCH 7/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/mm/stable_diffusion/anneal_sdxl.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/mm/stable_diffusion/anneal_sdxl.py b/examples/mm/stable_diffusion/anneal_sdxl.py index ab243e31c..12a95b3ca 100644 --- a/examples/mm/stable_diffusion/anneal_sdxl.py +++ b/examples/mm/stable_diffusion/anneal_sdxl.py @@ -229,7 +229,7 @@ def checkpoint_check_fn(module): torch.distributed.barrier() ckpt_callback = add_custom_checkpoint_callback(trainer, ptl_model) - timer = Timer(cfg.exp_manager.get("max_time_per_run", None) if cfg.exp_manager else None) + timer = Timer(cfg.exp_manager.get("max_time_per_run", None) if cfg.exp_manager else None) draft_p_trainer = SupervisedTrainer( cfg=cfg.trainer.draftp_sd, @@ -308,7 +308,9 @@ def checkpoint_check_fn(module): for i in range(images.shape[0]): i = i + global_idx img_path = os.path.join(exp_dir, f"annealed_outputs_sdxl_{wt_type}/img_{i:05d}_{local_rank:02d}.png") - prompt_path = os.path.join(exp_dir, f"annealed_outputs_sdxl_{wt_type}/prompt_{i:05d}_{local_rank:02d}.txt") + prompt_path = 
os.path.join( + exp_dir, f"annealed_outputs_sdxl_{wt_type}/prompt_{i:05d}_{local_rank:02d}.txt" + ) Image.fromarray(images[i]).save(img_path) with open(prompt_path, "w") as fi: fi.write(batch[i]) From db7c5fbb6b7c7353e69b34f2cde9959ed229fb9f Mon Sep 17 00:00:00 2001 From: Rohit Jena Date: Fri, 6 Sep 2024 13:18:59 -0700 Subject: [PATCH 8/8] changed annealed_out_dir and added todo to update megatrontrainerbuilder functionality Signed-off-by: Rohit Jena --- examples/mm/stable_diffusion/anneal_sdxl.py | 11 +++++++---- examples/mm/stable_diffusion/train_sdxl_draftp.py | 2 ++ 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/examples/mm/stable_diffusion/anneal_sdxl.py b/examples/mm/stable_diffusion/anneal_sdxl.py index ab243e31c..e87da4a73 100644 --- a/examples/mm/stable_diffusion/anneal_sdxl.py +++ b/examples/mm/stable_diffusion/anneal_sdxl.py @@ -94,6 +94,8 @@ mp.set_start_method("spawn", force=True) +# TODO: this functionality should go into NeMo +# Specifically, the NeMo MegatronTrainerBuilder must also accept extra FSDP wrap modules so that it doesn't need to be subclassed class MegatronStableDiffusionTrainerBuilder(MegatronTrainerBuilder): """Builder for SD model Trainer with overrides.""" @@ -283,10 +285,11 @@ def checkpoint_check_fn(module): logging.info(f"using weighing type for annealed outputs: {wt_type}.") # initialize generator - exp_dir = cfg.exp_manager.explicit_log_dir + annealed_out_dir = os.path.join(cfg.exp_manager.explicit_log_dir, f"annealed_outputs_sdxl_{wt_type}/") + # generate random seed for reproducibility and make output dir gen = torch.Generator(device="cpu") gen.manual_seed((1243 + 1247837 * local_rank) % (int(2 ** 32 - 1))) - os.makedirs(os.path.join(exp_dir, f"annealed_outputs_sdxl_{wt_type}/"), exist_ok=True) + os.makedirs(annealed_out_dir, exist_ok=True) for batch in val_dataloader: batch_size = len(batch) @@ -307,8 +310,8 @@ def checkpoint_check_fn(module): # save to pil for i in range(images.shape[0]): i = i + global_idx -
img_path = os.path.join(exp_dir, f"annealed_outputs_sdxl_{wt_type}/img_{i:05d}_{local_rank:02d}.png") - prompt_path = os.path.join(exp_dir, f"annealed_outputs_sdxl_{wt_type}/prompt_{i:05d}_{local_rank:02d}.txt") + img_path = os.path.join(annealed_out_dir, f"img_{i:05d}_{local_rank:02d}.png") + prompt_path = os.path.join(annealed_out_dir, f"prompt_{i:05d}_{local_rank:02d}.txt") Image.fromarray(images[i]).save(img_path) with open(prompt_path, "w") as fi: fi.write(batch[i]) diff --git a/examples/mm/stable_diffusion/train_sdxl_draftp.py b/examples/mm/stable_diffusion/train_sdxl_draftp.py index cc56c2b59..8e7d36a1a 100644 --- a/examples/mm/stable_diffusion/train_sdxl_draftp.py +++ b/examples/mm/stable_diffusion/train_sdxl_draftp.py @@ -91,6 +91,8 @@ mp.set_start_method("spawn", force=True) +# TODO: this functionality should go into NeMo +# Specifically, the NeMo MegatronTrainerBuilder must also accept extra FSDP wrap modules so that it doesn't need to be subclassed class MegatronStableDiffusionTrainerBuilder(MegatronTrainerBuilder): """Builder for SD model Trainer with overrides."""