-
Notifications
You must be signed in to change notification settings - Fork 1.4k
[megatron] support gemma4 megatron #9296
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
152593c
5086cdb
5c9f86b
e732b71
f37f967
d7c2f28
96ff166
fd783fc
c8aa1c1
845ad34
72a258d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,59 @@ | ||
| # 8 * 80GiB | ||
| # Due to the use of group_by_length, the data is not sufficiently shuffled, | ||
| # which may cause fluctuations in the loss curve. Please adjust the parameters accordingly. | ||
| PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \ | ||
| NPROC_PER_NODE=8 \ | ||
| CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ | ||
| megatron sft \ | ||
| --model google/gemma-4-26B-A4B-it \ | ||
| --save_safetensors true \ | ||
| --dataset 'AI-ModelScope/alpaca-gpt4-data-zh#500' \ | ||
| 'AI-ModelScope/alpaca-gpt4-data-en#500' \ | ||
| 'swift/self-cognition#500' \ | ||
| 'AI-ModelScope/LaTeX_OCR:human_handwrite#2000' \ | ||
| --load_from_cache_file true \ | ||
| --add_non_thinking_prefix true \ | ||
| --split_dataset_ratio 0.01 \ | ||
| --tuner_type full \ | ||
| --tensor_model_parallel_size 2 \ | ||
| --expert_model_parallel_size 4 \ | ||
| --pipeline_model_parallel_size 2 \ | ||
| --moe_permute_fusion true \ | ||
| --moe_grouped_gemm true \ | ||
| --moe_shared_expert_overlap true \ | ||
| --moe_aux_loss_coeff 1e-6 \ | ||
| --micro_batch_size 8 \ | ||
| --global_batch_size 16 \ | ||
| --recompute_granularity full \ | ||
| --recompute_method uniform \ | ||
| --recompute_num_layers 1 \ | ||
| --num_train_epochs 1 \ | ||
| --finetune true \ | ||
| --freeze_llm false \ | ||
| --freeze_vit true \ | ||
| --freeze_aligner true \ | ||
| --cross_entropy_loss_fusion true \ | ||
| --lr 1e-5 \ | ||
| --lr_warmup_fraction 0.05 \ | ||
| --min_lr 1e-6 \ | ||
| --output_dir megatron_output/gemma-4-26B-A4B-it \ | ||
| --eval_steps 500 \ | ||
| --save_steps 500 \ | ||
| --max_length 4096 \ | ||
| --dataloader_num_workers 8 \ | ||
| --dataset_num_proc 8 \ | ||
| --no_save_optim true \ | ||
| --no_save_rng true \ | ||
| --sequence_parallel true \ | ||
| --attention_backend unfused \ | ||
| --group_by_length true \ | ||
| --padding_free false \ | ||
| --model_author swift \ | ||
| --model_name swift-robot | ||
|
|
||
| # CUDA_VISIBLE_DEVICES=0 swift infer \ | ||
| # --model megatron_output/gemma-4-26B-A4B-it/vx-xxx/checkpoint-xxx \ | ||
| # --stream true \ | ||
| # --enable_thinking false \ | ||
| # --load_data_args true \ | ||
| # --max_new_tokens 2048 | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -62,10 +62,13 @@ def _model_cpu_forward_context(modules, | |||||||||||||||||||||||||||||
| compute_device=None, | ||||||||||||||||||||||||||||||
| share_embedding: bool = False, | ||||||||||||||||||||||||||||||
| target_device='cpu'): | ||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||
| origin_torch_dtype = next(modules[0].parameters()).dtype | ||||||||||||||||||||||||||||||
| except StopIteration: | ||||||||||||||||||||||||||||||
| origin_torch_dtype = next(modules[-1].parameters()).dtype | ||||||||||||||||||||||||||||||
| for module in modules: | ||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||
| origin_torch_dtype = next(module.parameters()).dtype | ||||||||||||||||||||||||||||||
| except StopIteration: | ||||||||||||||||||||||||||||||
| pass | ||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||
| break | ||||||||||||||||||||||||||||||
|
Comment on lines
+65
to
+71
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The variable
Suggested change
|
||||||||||||||||||||||||||||||
| embeddings = None | ||||||||||||||||||||||||||||||
| if share_embedding: | ||||||||||||||||||||||||||||||
| embeddings = [module for module in modules if isinstance(module, (nn.Embedding, VocabParallelEmbedding))] | ||||||||||||||||||||||||||||||
|
|
@@ -77,7 +80,7 @@ def _to_cuda_hook(module, args): | |||||||||||||||||||||||||||||
| return args | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def _to_cpu_hook(module, args, output): | ||||||||||||||||||||||||||||||
| if share_embedding and module in embeddings: | ||||||||||||||||||||||||||||||
| if share_embedding and module in embeddings or 'rotaryemb' in module.__class__.__name__.lower(): | ||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The boolean expression relies on operator precedence (
Suggested change
|
||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||
| module.to(device=target_device, dtype=origin_torch_dtype) | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The parallelism configuration appears inconsistent with the total number of GPUs (
NPROC_PER_NODE=8). Withtensor_model_parallel_size=2andpipeline_model_parallel_size=2, the Data Parallel (DP) size is calculated as8 / (2 * 2) = 2. In Megatron-Core, the Expert Parallel (EP) size (expert_model_parallel_size) must typically be less than or equal to the DP size (EP <= DP). SettingEP=4whileDP=2will likely result in a runtime error during model initialization.