diff --git a/deepmd/main.py b/deepmd/main.py index 492f3b085e..65d417b6b4 100644 --- a/deepmd/main.py +++ b/deepmd/main.py @@ -854,6 +854,47 @@ def main_parser() -> argparse.ArgumentParser: choices=["model-branch", "type-map", "descriptor", "fitting-net", "size"], nargs="+", ) + # grad-probe: per-task descriptor gradient probe + parser_grad_probe = subparsers.add_parser( + "grad-probe", + parents=[parser_log, parser_mpi_log], + help="Compute per-task descriptor gradient vectors from a multitask checkpoint.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser_grad_probe.add_argument("INPUT", help="Training config JSON file.") + parser_grad_probe.add_argument( + "--ckpt", required=True, help="Checkpoint .pt file path." + ) + parser_grad_probe.add_argument( + "-o", + "--output", + default="descriptor_grads.npz", + help="Output NPZ file path.", + ) + parser_grad_probe.add_argument( + "-n", + "--nbatches", + type=int, + default=1, + help="Number of batches per task to average gradient over.", + ) + parser_grad_probe.add_argument( + "-k", + "--accumulate-k", + type=int, + default=1, + dest="accumulate_k", + help="Group size for gradient accumulation before computing each dot product. " + "Total batches collected = nbatches; number of dot product samples = nbatches // k. " + "Larger K reduces norm variance and improves SNR of mean/std.", + ) + parser_grad_probe.add_argument( + "--pcgrad", + action="store_true", + default=False, + help="Apply PCGrad projection to descriptor gradients before analysis. 
def _descriptor_grad_vector(descriptor):
    """Flatten the gradients of a descriptor's trainable parameters.

    Every parameter with ``requires_grad=True`` contributes a segment, in
    ``named_parameters()`` order. A trainable parameter whose ``grad`` is
    ``None`` contributes zeros, so the vector layout always matches the
    returned ``names`` / ``shapes`` and is the same for every task that
    uses an identically structured descriptor.

    Parameters
    ----------
    descriptor
        Object exposing ``named_parameters()`` (a torch module).

    Returns
    -------
    tuple
        ``(vector, names, shapes)``. ``vector`` is a 1-D float32 numpy array,
        or ``None`` when no trainable parameter received a gradient.
    """
    import numpy as np

    names = []
    shapes = []
    pieces = []
    has_grad = False
    for name, param in descriptor.named_parameters():
        if not param.requires_grad:
            continue
        names.append(name)
        shapes.append(list(param.shape))
        if param.grad is None:
            # Keep the slot so offsets stay aligned with names/shapes.
            pieces.append(np.zeros(param.numel(), dtype=np.float32))
        else:
            has_grad = True
            pieces.append(param.grad.detach().cpu().float().numpy().ravel())
    vector = np.concatenate(pieces) if has_grad else None
    return vector, names, shapes


def _pcgrad_project_pair(g0, g1):
    """Apply the two-task PCGrad projection to a pair of gradient vectors.

    When the dot product is negative, each vector has its component along
    the *original* other vector removed. Otherwise (including a NaN dot
    product) both vectors are returned unchanged.

    Returns
    -------
    tuple
        The (possibly projected) pair ``(g0, g1)``.
    """
    import numpy as np

    dot = float(np.dot(g0, g1))
    if not dot < 0:
        return g0, g1
    projected_0 = g0 - dot / max(np.dot(g1, g1), 1e-12) * g1
    projected_1 = g1 - dot / max(np.dot(g0, g0), 1e-12) * g0
    return projected_0, projected_1


def grad_probe(FLAGS) -> None:
    """Compute per-task descriptor gradient vectors for multitask conflict analysis.

    For each of ``FLAGS.nbatches`` rounds, one batch is drawn per task, the
    loss is back-propagated and the descriptor gradients are flattened into
    a vector. Vectors are summed per task over all rounds (global average)
    and over consecutive groups of ``accumulate_k`` rounds (group means).
    Group norms, pairwise dot products and cosine similarities are logged
    and written, together with the averaged gradient vectors, to the NPZ
    file ``FLAGS.output``.

    Parameters
    ----------
    FLAGS
        Parsed command-line namespace. Reads ``INPUT``, ``ckpt``, ``output``
        and ``nbatches``; ``accumulate_k`` (default 1) and ``pcgrad``
        (default False) are optional.

    Raises
    ------
    ValueError
        If ``nbatches`` or ``accumulate_k`` is smaller than 1, or if
        ``accumulate_k`` exceeds ``nbatches`` (no complete group would exist).
    """
    from itertools import combinations

    import numpy as np

    from deepmd.common import j_loader
    from deepmd.pt.utils.multi_task import preprocess_shared_params
    from deepmd.utils.argcheck import normalize
    from deepmd.utils.compat import update_deepmd_input

    total_batches = FLAGS.nbatches
    accumulate_k = getattr(FLAGS, "accumulate_k", 1)
    use_pcgrad = getattr(FLAGS, "pcgrad", False)

    # Fail fast, before the checkpoint is loaded.
    if total_batches < 1:
        raise ValueError(f"nbatches must be >= 1, got {total_batches}")
    if accumulate_k < 1:
        raise ValueError(f"accumulate_k must be >= 1, got {accumulate_k}")
    if accumulate_k > total_batches:
        raise ValueError(
            f"accumulate_k ({accumulate_k}) must not exceed nbatches "
            f"({total_batches}); no complete group would be collected"
        )
    leftover = total_batches % accumulate_k
    if leftover:
        log.warning(
            "grad_probe: nbatches=%d is not a multiple of k=%d; the last %d "
            "batch(es) enter the global average but not the group statistics.",
            total_batches,
            accumulate_k,
            leftover,
        )

    config = j_loader(FLAGS.INPUT)
    config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json")
    multi_task = "model_dict" in config.get("model", {})
    shared_links = None
    if multi_task:
        config["model"], shared_links = preprocess_shared_params(config["model"])
    config = normalize(config, multi_task=multi_task)

    trainer = get_trainer(
        config,
        restart_model=FLAGS.ckpt,
        shared_links=shared_links,
    )
    trainer.wrapper.eval()

    # Unwrap a parallel wrapper if there is one.
    module = getattr(trainer.wrapper, "module", trainer.wrapper)

    if use_pcgrad:
        log.info("grad_probe: PCGrad projection enabled — analysing projected descriptor gradients.")

    model_keys = list(trainer.model_keys)
    task_pairs = list(combinations(model_keys, 2))

    task_group_norms = {key: [] for key in model_keys}
    task_accum_grads = {}  # task -> sum of grad vectors over all batches
    group_sums = {}  # task -> sum of grad vectors in the current group
    pairwise_dots = {pair: [] for pair in task_pairs}
    pairwise_cos = {pair: [] for pair in task_pairs}
    mismatched_pairs = set()  # pairs already reported as incomparable
    pcgrad_skip_warned = False
    param_names = None
    param_shapes = None

    cur_lr = config["learning_rate"]["start_lr"]

    for batch_idx in range(total_batches):
        # Collect the raw descriptor gradient vector of every task.
        batch_grad_vecs = {}
        for task_key in model_keys:
            trainer.optimizer.zero_grad(set_to_none=True)
            input_dict, label_dict, _ = trainer.get_data(
                is_train=True, task_key=task_key
            )
            _, loss, _ = trainer.wrapper(
                **input_dict, cur_lr=cur_lr, label=label_dict, task_key=task_key
            )
            loss.backward()

            descriptor = module.model[task_key].get_descriptor()
            grad_vec, names, shapes = _descriptor_grad_vector(descriptor)
            if grad_vec is None:
                continue
            batch_grad_vecs[task_key] = grad_vec
            if param_names is None:
                param_names, param_shapes = names, shapes

        # Optional PCGrad projection (defined here for exactly two tasks).
        if use_pcgrad:
            can_project = len(batch_grad_vecs) == 2
            if can_project:
                k0, k1 = batch_grad_vecs
                g0, g1 = batch_grad_vecs[k0], batch_grad_vecs[k1]
                can_project = g0.shape == g1.shape
            if can_project:
                batch_grad_vecs[k0], batch_grad_vecs[k1] = _pcgrad_project_pair(
                    g0, g1
                )
            elif not pcgrad_skip_warned:
                pcgrad_skip_warned = True
                log.warning(
                    "grad_probe: PCGrad projection skipped; it needs exactly two "
                    "task gradient vectors of equal length (got %d vector(s)).",
                    len(batch_grad_vecs),
                )

        # Accumulate the (projected) vectors globally and within the group.
        for task_key, grad_vec in batch_grad_vecs.items():
            for sums in (task_accum_grads, group_sums):
                if task_key in sums:
                    sums[task_key] += grad_vec
                else:
                    sums[task_key] = grad_vec.copy()

        if (batch_idx + 1) % accumulate_k != 0:
            continue

        # End of a group: normalise to a per-batch mean so the scale does
        # not depend on k, record the statistics, then reset the group.
        group_means = {key: vec / accumulate_k for key, vec in group_sums.items()}
        for task_key in model_keys:
            if task_key in group_means:
                task_group_norms[task_key].append(
                    np.linalg.norm(group_means[task_key])
                )
        for pair in task_pairs:
            g1 = group_means.get(pair[0])
            g2 = group_means.get(pair[1])
            if g1 is None or g2 is None:
                continue
            if g1.shape != g2.shape:
                if pair not in mismatched_pairs:
                    mismatched_pairs.add(pair)
                    log.warning(
                        "grad_probe: descriptor gradients of '%s' and '%s' have "
                        "different lengths (%d vs %d); similarity is skipped.",
                        pair[0],
                        pair[1],
                        g1.size,
                        g2.size,
                    )
                continue
            dot = float(np.dot(g1, g2))
            norm_product = float(np.linalg.norm(g1)) * float(np.linalg.norm(g2))
            pairwise_dots[pair].append(dot)
            # The cosine is stored with its dot product so the two lists
            # can never get out of step.
            pairwise_cos[pair].append(
                dot / norm_product if norm_product > 1e-12 else float("nan")
            )
        group_sums = {}

    grads_per_task = {}

    # Per-task statistics.
    for task_key in model_keys:
        accumulated = task_accum_grads.get(task_key)
        norms_arr = np.array(task_group_norms[task_key])
        if accumulated is not None and norms_arr.size:
            avg_grad_vec = accumulated / total_batches
            norm_mean = np.mean(norms_arr)
            norm_std = np.std(norms_arr)
        elif accumulated is not None:
            avg_grad_vec = accumulated / total_batches
            norm_mean = 0.0
            norm_std = 0.0
        else:
            avg_grad_vec = np.array([], dtype=np.float32)
            norm_mean = 0.0
            norm_std = 0.0

        G_sq_norm = float(np.dot(avg_grad_vec, avg_grad_vec))
        if norms_arr.size and G_sq_norm > 1e-30:
            # tr(Sigma) estimate: mean squared group norm minus ||G||^2.
            tr_sigma = float(np.mean(norms_arr**2)) - G_sq_norm
            noise_scale = tr_sigma / G_sq_norm
        else:
            noise_scale = float("nan")

        grads_per_task[task_key] = avg_grad_vec
        grads_per_task[f"{task_key}_norm_mean"] = norm_mean
        grads_per_task[f"{task_key}_norm_std"] = norm_std
        grads_per_task[f"{task_key}_noise_scale"] = noise_scale

        log.info(
            "Task '%s': collected %d batches (groups=%d, k=%d). "
            "Avg Grad Norm=%.4e, Mean(Group Norms)=%.4e, Std(Group Norms)=%.4e, "
            "Noise Scale(group)=%.4e",
            task_key,
            total_batches,
            total_batches // accumulate_k,
            accumulate_k,
            float(np.linalg.norm(avg_grad_vec)),
            float(norm_mean),
            float(norm_std),
            noise_scale,
        )

    # Pairwise dot-product and cosine-similarity statistics.
    for (k1, k2), dots in pairwise_dots.items():
        if not dots:
            continue
        dot_mean = float(np.mean(dots))
        dot_std = float(np.std(dots))

        g1_avg = grads_per_task[k1]
        g2_avg = grads_per_task[k2]
        dot_global = float(np.dot(g1_avg, g2_avg))

        # Group-level cosine: statistics over groups with a usable norm.
        cos_arr = np.array(pairwise_cos[(k1, k2)], dtype=np.float64)
        valid_cos = cos_arr[~np.isnan(cos_arr)]
        if valid_cos.size:
            cos_group_mean = float(np.mean(valid_cos))
            cos_group_std = float(np.std(valid_cos))
        else:
            cos_group_mean = float("nan")
            cos_group_std = float("nan")

        # Global-level cosine: between the two averaged gradient vectors.
        denom_global = float(np.linalg.norm(g1_avg)) * float(np.linalg.norm(g2_avg))
        cos_global = (
            float(dot_global / denom_global)
            if denom_global > 1e-12
            else float("nan")
        )

        grads_per_task[f"dot_{k1}_{k2}_mean"] = dot_mean
        grads_per_task[f"dot_{k1}_{k2}_std"] = dot_std
        grads_per_task[f"dot_{k1}_{k2}_global"] = dot_global
        grads_per_task[f"cos_group_mean_{k1}_{k2}"] = cos_group_mean
        grads_per_task[f"cos_group_std_{k1}_{k2}"] = cos_group_std
        grads_per_task[f"cos_global_{k1}_{k2}"] = cos_global

        log.info(
            "Grad similarity '%s' vs '%s': "
            "dot_mean=%.4e, dot_std=%.4e, dot_global=%.4e, "
            "cos_group_mean=%.4f, cos_group_std=%.4f, cos_global=%.4f "
            "(n_groups=%d)",
            k1,
            k2,
            dot_mean,
            dot_std,
            dot_global,
            cos_group_mean,
            cos_group_std,
            cos_global,
            len(dots),
        )

    save_dict = {f"grads_{k}": v for k, v in grads_per_task.items()}
    # Empty lists (not None) when no task produced descriptor gradients.
    save_dict["param_names"] = np.array(param_names or [], dtype=object)
    save_dict["param_shapes"] = np.array(param_shapes or [], dtype=object)
    np.savez(FLAGS.output, **save_dict)
    log.info("Descriptor gradient vectors and similarity stats saved to: %s", FLAGS.output)
@@ -220,7 +214,7 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False): l2_ener_loss.detach(), find_energy ) if not self.use_huber: - loss += atom_norm * (pref_e * l2_ener_loss) + loss += atom_norm**2 * (pref_e * l2_ener_loss) else: l_huber_loss = custom_huber_loss( atom_norm * model_pred["energy"], @@ -346,8 +340,8 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False): ) else: l_huber_loss = custom_huber_loss( - (atom_pref * force_pred).reshape(-1), - (atom_pref * force_label).reshape(-1), + atom_pref_reshape * force_pred.reshape(-1), + atom_pref_reshape * force_label.reshape(-1), delta=self.huber_delta, ) loss += pref_pf * l_huber_loss @@ -405,7 +399,7 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False): l2_virial_loss.detach(), find_virial ) if not self.use_huber: - loss += atom_norm * (pref_v * l2_virial_loss) + loss += atom_norm**2 * (pref_v * l2_virial_loss) else: l_huber_loss = custom_huber_loss( atom_norm * model_pred["virial"].reshape(-1), diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py index 97b1e29da3..82944442cc 100644 --- a/deepmd/pt/model/descriptor/dpa1.py +++ b/deepmd/pt/model/descriptor/dpa1.py @@ -656,6 +656,8 @@ def forward( nlist: torch.Tensor, mapping: Optional[torch.Tensor] = None, comm_dict: Optional[dict[str, torch.Tensor]] = None, + fparam: Optional[torch.Tensor] = None, + case_embd: Optional[torch.Tensor] = None, ): """Compute the descriptor. @@ -671,6 +673,10 @@ def forward( The index mapping, not required by this descriptor. comm_dict The data needed for communication for parallel inference. + fparam + The frame-level parameters. shape: nf x nfparam + case_embd + The case (dataset) embedding for multitask training with shared fitting. 
shape: nf x dim_case_embd Returns ------- diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 7e0761915f..a3660473c2 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -559,12 +559,13 @@ def collect_single_finetune_params( for kk in rm_list: state_dict.pop(kk) state_dict["_extra_state"]["model_params"] = old_model_params - out_shape_list = [ - "model.Default.atomic_model.out_bias", - "model.Default.atomic_model.out_std", - ] + out_shape_list = [] + for model_key in self.model_keys: + out_shape_list.append(f"model.{model_key}.atomic_model.out_bias") + out_shape_list.append(f"model.{model_key}.atomic_model.out_std") for kk in out_shape_list: - state_dict[kk] = state_dict[kk][:1, :, :1] + if kk in state_dict: + state_dict[kk] = state_dict[kk][:1, :, :1] self.wrapper.load_state_dict(state_dict) # change bias for fine-tuning @@ -625,6 +626,47 @@ def single_model_finetune( assert sum_prob > 0.0, "Sum of model prob must be larger than 0!" self.model_prob = self.model_prob / sum_prob + self.use_pcgrad = self.multi_task and training_params.get("use_pcgrad", False) + self.use_dual_batch = self.multi_task and ( + self.use_pcgrad or training_params.get("use_dual_batch", False) + ) + self.use_alternating = self.multi_task and ( + not self.use_dual_batch + and training_params.get("alternating_tasks", False) + ) + self.use_grad_norm_reweight = self.use_dual_batch and training_params.get( + "grad_norm_reweight", False + ) + self.use_loss_ratio_reweight = self.use_dual_batch and training_params.get( + "loss_ratio_reweight", False + ) + self.reweight_ema_decay = training_params.get("reweight_ema_decay", 0.99) + if self.use_grad_norm_reweight or self.use_loss_ratio_reweight: + # Per-component EMA norms: descriptor and shared fitting_net tracked separately + self.grad_norm_ema_desc = {k: None for k in self.model_keys} + self.grad_norm_ema_fit = {k: None for k in self.model_keys} + # Two-speed EMA for loss: relative rate = fast/slow, 
scale-invariant + self.loss_val_ema_fast = {k: None for k in self.model_keys} + self.loss_val_ema_slow = {k: None for k in self.model_keys} + if self.use_pcgrad: + log.info("PCGrad enabled: descriptor gradients will be projected each step.") + elif self.use_dual_batch: + reweight_modes = [] + if self.use_grad_norm_reweight: + reweight_modes.append("grad-norm-EMA") + if self.use_loss_ratio_reweight: + reweight_modes.append("loss-ratio") + if reweight_modes: + log.info( + "Dual-batch enabled with reweighting: %s (ema_decay=%.3f).", + "+".join(reweight_modes), + self.reweight_ema_decay, + ) + else: + log.info("Dual-batch enabled: all tasks sampled per step, gradients averaged.") + elif self.use_alternating: + log.info("Alternating-tasks enabled: tasks cycled deterministically A→B→A→B each step.") + # Multi-task share params if shared_links is not None: _data_stat_protect = np.array( @@ -735,11 +777,14 @@ def run(self) -> None: def step(_step_id, task_key="Default") -> None: if self.multi_task: - model_index = dp_random.choice( - np.arange(self.num_model, dtype=np.int_), - p=self.model_prob, - ) - task_key = self.model_keys[model_index] + if self.use_alternating: + task_key = self.model_keys[_step_id % self.num_model] + else: + model_index = dp_random.choice( + np.arange(self.num_model, dtype=np.int_), + p=self.model_prob, + ) + task_key = self.model_keys[model_index] # PyTorch Profiler if self.enable_profiler or self.profiling: prof.step() @@ -750,23 +795,194 @@ def step(_step_id, task_key="Default") -> None: cur_lr = _lr.value(_step_id) pref_lr = cur_lr self.optimizer.zero_grad(set_to_none=True) - input_dict, label_dict, log_dict = self.get_data( - is_train=True, task_key=task_key - ) - if SAMPLER_RECORD: - print_str = f"Step {_step_id}: sample system{log_dict['sid']} frame{log_dict['fid']}\n" - fout1.write(print_str) - fout1.flush() + if not self.use_dual_batch: + input_dict, label_dict, log_dict = self.get_data( + is_train=True, task_key=task_key + ) + if 
SAMPLER_RECORD: + print_str = f"Step {_step_id}: sample system{log_dict['sid']} frame{log_dict['fid']}\n" + fout1.write(print_str) + fout1.flush() if self.opt_type in ["Adam", "AdamW"]: cur_lr = self.scheduler.get_last_lr()[0] if _step_id < self.warmup_steps: pref_lr = _lr.start_lr else: pref_lr = cur_lr - model_pred, loss, more_loss = self.wrapper( - **input_dict, cur_lr=pref_lr, label=label_dict, task_key=task_key - ) - loss.backward() + if self.use_dual_batch: + all_params = list(self.wrapper.parameters()) + task_grads = {} + more_loss_per_task = {} + task_loss_val = {} + for tk in self.model_keys: + self.optimizer.zero_grad(set_to_none=True) + in_d, lbl_d, log_d = self.get_data(is_train=True, task_key=tk) + if SAMPLER_RECORD and tk == self.model_keys[0]: + print_str = f"Step {_step_id}: sample system{log_d['sid']} frame{log_d['fid']}\n" + fout1.write(print_str) + fout1.flush() + _, loss, more_loss = self.wrapper( + **in_d, cur_lr=pref_lr, label=lbl_d, task_key=tk + ) + loss.backward() + task_grads[tk] = { + id(p): p.grad.clone() if p.grad is not None else None + for p in all_params + } + more_loss_per_task[tk] = more_loss + task_loss_val[tk] = loss.item() + + k0, k1 = self.model_keys[0], self.model_keys[1] + + if self.use_pcgrad: + module = ( + self.wrapper.module + if hasattr(self.wrapper, "module") + else self.wrapper + ) + descriptor = module.model[k0].get_descriptor() + desc_param_ids = {id(p) for p in descriptor.parameters()} + desc_pids = [ + pid for pid in desc_param_ids + if task_grads[k0].get(pid) is not None + and task_grads[k1].get(pid) is not None + ] + if desc_pids: + g0 = torch.cat([task_grads[k0][pid].flatten() for pid in desc_pids]) + g1 = torch.cat([task_grads[k1][pid].flatten() for pid in desc_pids]) + dot = (g0 * g1).sum() + projected = dot < 0 + if projected: + g0_proj = g0 - dot / (g1 * g1).sum().clamp(min=1e-12) * g1 + g1_proj = g1 - dot / (g0 * g0).sum().clamp(min=1e-12) * g0 + offset = 0 + for pid in desc_pids: + numel = 
task_grads[k0][pid].numel() + shape = task_grads[k0][pid].shape + task_grads[k0][pid] = g0_proj[offset:offset + numel].reshape(shape) + task_grads[k1][pid] = g1_proj[offset:offset + numel].reshape(shape) + offset += numel + if self.rank == 0 and (_step_id + 1) % self.disp_freq == 0: + cos_sim = dot / (g0.norm() * g1.norm()).clamp(min=1e-12) + log.info( + "PCGrad step %d: desc dot=%.4e, cos_sim=%.4f, projected=%s", + _step_id + 1, dot.item(), cos_sim.item(), projected, + ) + + # Gradient combination: per-component reweighting when active, + # otherwise simple equal average for shared params + self.optimizer.zero_grad(set_to_none=True) + num_tasks = len(self.model_keys) + + if self.use_grad_norm_reweight or self.use_loss_ratio_reweight: + d = self.reweight_ema_decay + d_slow = 1.0 - (1.0 - d) / 10.0 + + _module = ( + self.wrapper.module + if hasattr(self.wrapper, "module") + else self.wrapper + ) + _desc_param_ids = { + id(p) for p in _module.model[k0].get_descriptor().parameters() + } + _shared_pids = { + id(p) for p in all_params + if task_grads[k0].get(id(p)) is not None + and task_grads[k1].get(id(p)) is not None + } + _fit_shared_pids = _shared_pids - _desc_param_ids + + if self.use_grad_norm_reweight: + for k in self.model_keys: + desc_g = [ + task_grads[k][pid] + for pid in (_shared_pids & _desc_param_ids) + if task_grads[k].get(pid) is not None + ] + if desc_g: + cur = torch.stack([g.norm() for g in desc_g]).norm().item() + if self.grad_norm_ema_desc[k] is None: + self.grad_norm_ema_desc[k] = cur + else: + self.grad_norm_ema_desc[k] = d * self.grad_norm_ema_desc[k] + (1 - d) * cur + fit_g = [ + task_grads[k][pid] + for pid in _fit_shared_pids + if task_grads[k].get(pid) is not None + ] + if fit_g: + cur = torch.stack([g.norm() for g in fit_g]).norm().item() + if self.grad_norm_ema_fit[k] is None: + self.grad_norm_ema_fit[k] = cur + else: + self.grad_norm_ema_fit[k] = d * self.grad_norm_ema_fit[k] + (1 - d) * cur + + if self.use_loss_ratio_reweight: + for k in 
self.model_keys: + v = task_loss_val[k] + if self.loss_val_ema_fast[k] is None: + self.loss_val_ema_fast[k] = v + self.loss_val_ema_slow[k] = v + else: + self.loss_val_ema_fast[k] = d * self.loss_val_ema_fast[k] + (1 - d) * v + self.loss_val_ema_slow[k] = d_slow * self.loss_val_ema_slow[k] + (1 - d_slow) * v + + # Build per-component normalized weights + def _w(norm_ema): + return {k: 1.0 / (norm_ema[k] + 1e-8) if norm_ema[k] is not None else 1.0 + for k in self.model_keys} + + dw = _w(self.grad_norm_ema_desc) if self.use_grad_norm_reweight else {k: 1.0 for k in self.model_keys} + fw = _w(self.grad_norm_ema_fit) if self.use_grad_norm_reweight else {k: 1.0 for k in self.model_keys} + + if self.use_loss_ratio_reweight: + for k in self.model_keys: + rel = (self.loss_val_ema_fast[k] / (self.loss_val_ema_slow[k] + 1e-8) + if self.loss_val_ema_fast[k] is not None else 1.0) + dw[k] *= rel + fw[k] *= rel + + dw_sum = sum(dw.values()) + fw_sum = sum(fw.values()) + desc_w = {k: dw[k] / dw_sum for k in self.model_keys} + fit_w = {k: fw[k] / fw_sum for k in self.model_keys} + + if self.rank == 0 and (_step_id + 1) % self.disp_freq == 0: + log.info( + "Reweight step %d: desc=[%s] fit=[%s]", + _step_id + 1, + ", ".join(f"{k}={desc_w[k]:.4f}" for k in self.model_keys), + ", ".join(f"{k}={fit_w[k]:.4f}" for k in self.model_keys), + ) + + for p in all_params: + pid = id(p) + g0p, g1p = task_grads[k0][pid], task_grads[k1][pid] + if g0p is not None and g1p is not None: + w0, w1 = (desc_w[k0], desc_w[k1]) if pid in _desc_param_ids else (fit_w[k0], fit_w[k1]) + p.grad = w0 * g0p + w1 * g1p + elif g0p is not None: + p.grad = g0p + elif g1p is not None: + p.grad = g1p + else: + for p in all_params: + pid = id(p) + g0p, g1p = task_grads[k0][pid], task_grads[k1][pid] + if g0p is not None and g1p is not None: + p.grad = (g0p + g1p) / num_tasks + elif g0p is not None: + p.grad = g0p + elif g1p is not None: + p.grad = g1p + + more_loss = more_loss_per_task[task_key] + else: + _, loss, 
more_loss = self.wrapper( + **input_dict, cur_lr=pref_lr, label=label_dict, task_key=task_key + ) + loss.backward() if self.gradient_max_norm > 0.0: torch.nn.utils.clip_grad_norm_( self.wrapper.parameters(), diff --git a/deepmd/pt/utils/dataloader.py b/deepmd/pt/utils/dataloader.py index bc771b41d4..716dbc98c5 100644 --- a/deepmd/pt/utils/dataloader.py +++ b/deepmd/pt/utils/dataloader.py @@ -214,6 +214,7 @@ def print_summary( name: str, prob: list[float], ) -> None: + return rank = dist.get_rank() if dist.is_initialized() else 0 if rank == 0: print_summary( diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index bbb0c01a4c..bf7ea5944f 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -3386,6 +3386,12 @@ def training_args( if not multi_task else [ Argument("model_prob", dict, optional=True, default={}, doc=doc_model_prob), + Argument("use_pcgrad", bool, optional=True, default=False, doc="Apply PCGrad gradient surgery on the shared descriptor parameters in multi-task training."), + Argument("use_dual_batch", bool, optional=True, default=False, doc="Sample all tasks every step and sum gradients without projection. Use as control group to isolate PCGrad effect from dual-batch effect."), + Argument("alternating_tasks", bool, optional=True, default=False, doc="Cycle through tasks deterministically (A→B→A→B) each step instead of random sampling. 
Ablation control to isolate balanced-sampling effect from combined-gradient effect."), + Argument("grad_norm_reweight", bool, optional=True, default=False, doc="(dual-batch only) Reweight per-task gradients inversely proportional to their EMA gradient norm before combining, equalizing each task's directional contribution to shared parameters."), + Argument("loss_ratio_reweight", bool, optional=True, default=False, doc="(dual-batch only) Reweight per-task gradients proportional to their EMA loss value, giving more weight to the higher-loss task to prevent it from being sacrificed."), + Argument("reweight_ema_decay", float, optional=True, default=0.99, doc="EMA decay factor for grad_norm_reweight and loss_ratio_reweight tracking. Higher values give smoother but slower-adapting estimates."), Argument("data_dict", dict, data_args, repeat=True, doc=doc_data_dict), ] )