I reviewed the Stage 3 training/loss implementation in:
- /home/runner/work/Comic-Analysis/Comic-Analysis/src/version2/train_stage3.py
- /home/runner/work/Comic-Analysis/Comic-Analysis/src/version2/stage3_panel_features_framework.py
- /home/runner/work/Comic-Analysis/Comic-Analysis/src/version2/stage3_dataset.py
- /home/runner/work/Comic-Analysis/Comic-Analysis/src/version2/README_STAGE3.md
- /home/runner/work/Comic-Analysis/Comic-Analysis/src/version2/test_stage3_collapse.py
Note: I did not find train_stage3_model.py in this repository. The active Stage 3 training file appears to be train_stage3.py.
The training objective is a weighted sum:
- Contrastive loss (contrastive_loss) on panel embeddings, with positives defined as panels from the same page.
- Reconstruction loss (reconstruction_loss) that predicts one masked panel embedding from the mean of the remaining panel embeddings.
- Modality alignment loss (modality_alignment_loss) using image-to-text cross-entropy retrieval on panels where both image and text are present.
Combined in training as:
total_loss = 1.0 * contrastive_loss + 0.5 * reconstruction_loss + 0.3 * alignment_loss
This matches the Stage 3 README description at a high level.
- The objectives are conceptually aligned with Stage 3 goals (intra-page structure + multimodal alignment).
- Modality masks are used to avoid aligning missing modalities.
- Contrastive implementation follows a supervised-contrastive style with same-page positives and self-pair exclusion.
- Safety checks avoid crashes when too few valid panels exist.
In validate(), model.eval() is called, but objectives.eval() is not. Since Stage3TrainingObjectives contains Dropout in reconstruction_head, validation loss likely includes stochastic dropout behavior.
Impact: noisy/unstable validation metrics and potentially unreliable best-checkpoint selection.
What I would change (in code, if editing were allowed):
- Set objectives.eval() in validation.
- Set objectives.train() only during training.
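A minimal demonstration of why the mode matters: any `nn.Module` containing `Dropout` (as `Stage3TrainingObjectives` does via `reconstruction_head`) is stochastic in train mode and deterministic in eval mode. Using a stand-in module:

```python
import torch
import torch.nn as nn

# Stand-in for Stage3TrainingObjectives: any nn.Module with Dropout behaves
# stochastically in train mode and deterministically in eval mode.
objectives = nn.Sequential(nn.Dropout(p=0.5))
x = torch.ones(1000)

objectives.train()            # training: dropout zeroes ~half the inputs
y_train = objectives(x)

objectives.eval()             # validation: dropout becomes the identity
y_val = objectives(x)
assert torch.equal(y_val, x)  # deterministic in eval mode
```

Without the `eval()` call, every validation pass samples a different dropout mask, which is exactly the noise source flagged above.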
reconstruction_loss computes MSE between predicted context-based embedding and target embedding from the same current forward pass, with gradients flowing through both sides.
Impact: the model can reduce loss by moving both context and target embeddings together rather than learning stronger predictive structure, which can encourage over-smoothing.
What I would change:
- Detach the target embedding in reconstruction loss (predict a stable target).
- Optionally detach context too in an ablation to test stability vs learning signal.
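A sketch of the target detachment, using a plain `nn.Linear` as a stand-in for `reconstruction_head` and hypothetical tensor names:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

head = nn.Linear(16, 16)                         # stand-in for reconstruction_head
panels = torch.randn(5, 16, requires_grad=True)  # hypothetical valid panel embeddings
mask_idx = 2

context = torch.cat([panels[:mask_idx], panels[mask_idx + 1:]]).mean(dim=0)
pred = head(context)

target = panels[mask_idx].detach()               # stop gradients through the target
loss = F.mse_loss(pred, target)
loss.backward()

# Gradients now flow only through the context panels, not the masked target,
# so the model cannot shrink the loss by dragging the target toward the prediction.
assert panels.grad[mask_idx].abs().sum() == 0
```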
Current alignment is CrossEntropy(logits=image·text^T, labels=diag), i.e., image retrieval of text only.
Impact: weaker bidirectional alignment; text->image retrieval quality can lag.
What I would change:
- Use symmetric contrastive alignment: average of vision->text and text->vision CE losses.
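A CLIP-style symmetric version could look like this (hypothetical function name; the `temperature` default is illustrative):

```python
import torch
import torch.nn.functional as F

def symmetric_alignment_loss(img_emb, txt_emb, temperature=0.07):
    """Symmetric InfoNCE over aligned (image, text) panel pairs."""
    img = F.normalize(img_emb, dim=-1)
    txt = F.normalize(txt_emb, dim=-1)
    logits = img @ txt.T / temperature
    labels = torch.arange(logits.size(0), device=logits.device)
    loss_i2t = F.cross_entropy(logits, labels)    # image -> text retrieval
    loss_t2i = F.cross_entropy(logits.T, labels)  # text -> image retrieval
    return 0.5 * (loss_i2t + loss_t2i)
```

Transposing the same logits matrix gives the reverse retrieval direction for free, so the symmetric loss costs essentially nothing extra to compute.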
Weights are hardcoded in both train_epoch() and validate().
Impact: difficult tuning and reproducibility across datasets.
What I would change:
- Expose contrastive_weight, reconstruction_weight, and alignment_weight via args/config.
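A possible argparse sketch (flag names are suggestions; defaults mirror the currently hardcoded 1.0 / 0.5 / 0.3):

```python
import argparse

# Hypothetical CLI plumbing for the loss weights.
parser = argparse.ArgumentParser()
parser.add_argument('--contrastive-weight', type=float, default=1.0)
parser.add_argument('--reconstruction-weight', type=float, default=0.5)
parser.add_argument('--alignment-weight', type=float, default=0.3)
args = parser.parse_args([])  # in train_stage3.py, parse the real sys.argv

# Placeholder loss values standing in for the computed components.
contrastive, reconstruction, alignment = 2.0, 1.0, 0.5
total_loss = (args.contrastive_weight * contrastive
              + args.reconstruction_weight * reconstruction
              + args.alignment_weight * alignment)
```

Passing the same `args` into both `train_epoch()` and `validate()` keeps the two code paths from drifting apart.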
Multiple branches return zero loss when there are <2 valid entries.
Impact: effective batch objective can become weak depending on dataset panel distribution and text sparsity.
What I would change:
- Log per-component “active sample counts”.
- Consider sampling strategy to ensure enough multi-panel and text-present samples per batch.
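One way to surface those counts is a small diagnostics helper; this is a hypothetical sketch mirroring the <2-valid-entries branches described above, with assumed input names:

```python
import torch

def component_diagnostics(page_ids, has_text):
    """Count how many samples each loss component actually trains on.

    page_ids: (N,) page index per panel; has_text: (N,) bool text-presence mask.
    Hypothetical helper -- names and thresholds are illustrative.
    """
    n = len(page_ids)
    # contrastive/reconstruction need pages with >=2 panels
    counts = torch.bincount(page_ids)
    multi_panel = int(counts[counts >= 2].sum())
    # alignment needs >=2 panels with both modalities present
    text_present = int(has_text.sum())
    return {
        'contrastive_active': multi_panel if multi_panel >= 2 else 0,
        'alignment_active': text_present if text_present >= 2 else 0,
        'batch_size': n,
    }
```

Logging this dict per batch makes it obvious when a loss component silently drops out of the objective.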
Stage3TrainingObjectives accepts temperature, but the script instantiation does not expose CLI/config control for it.
Impact: reduced experiment control and reproducibility.
What I would change:
- Add CLI/config plumbed temperature for all contrastive/alignment logits.
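Sketch of the plumbing; `Stage3TrainingObjectives` already accepts `temperature` per the framework code, so only the CLI flag is new (its name is a suggestion):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--temperature', type=float, default=0.07)  # default is illustrative
args = parser.parse_args([])  # in train_stage3.py, parse the real sys.argv

# Only the flag is new; the constructor parameter already exists:
# objectives = Stage3TrainingObjectives(..., temperature=args.temperature)
temperature = args.temperature
```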
- Contrastive complexity scales quadratically with the number of valid panels in a batch (O(V^2) similarity matrix).
- Gradient clipping is applied only to model.parameters(), while the optimizer also includes objectives.parameters().
- test_stage3_collapse.py is useful and should remain part of regular regression checks after loss changes.
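For the gradient-clipping observation above, one fix is to clip exactly the parameter set the optimizer updates. A sketch with stand-in modules:

```python
import itertools
import torch
import torch.nn as nn

model = nn.Linear(8, 8)       # stand-in for the real model
objectives = nn.Linear(8, 4)  # stand-in for Stage3TrainingObjectives

# Build one parameter list and use it for BOTH the optimizer and clipping.
params = list(itertools.chain(model.parameters(), objectives.parameters()))
optimizer = torch.optim.AdamW(params, lr=1e-4)

loss = objectives(model(torch.randn(2, 8))).sum()
loss.backward()

# Clip everything the optimizer updates, not just model.parameters().
total_norm = nn.utils.clip_grad_norm_(params, max_norm=1.0)
optimizer.step()
```

Sharing one `params` list makes it impossible for the clipped set and the optimized set to diverge again.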
- Validation mode fix (objectives.eval() in validation).
- Reconstruction target detachment to stabilize representation learning.
- Symmetric alignment loss (vision<->text).
- Configurable loss weights + temperature plumbing.
- Add per-loss diagnostics (active pairs/samples, each component magnitude).
```python
# Validation: put both the model and the loss module in eval mode
model.eval()
objectives.eval()

# Reconstruction target stabilization: stop gradients through the target
target = valid_panels[mask_idx].detach()

# Symmetric alignment: average both retrieval directions
logits = sim(vision, text)
loss_image_to_text = CE(logits, diag_labels)
loss_text_to_image = CE(logits.T, diag_labels)
alignment = 0.5 * (loss_image_to_text + loss_text_to_image)

# Weighted total from config instead of hardcoded constants
total_loss = (contrastive_weight * contrastive
              + reconstruction_weight * reconstruction
              + alignment_weight * alignment)
```
The current Stage 3 loss framework is directionally strong, but there are a few high-impact issues in objective handling and optimization details (especially validation mode and reconstruction target gradients) that could materially affect stability and model quality. Addressing those first should improve confidence in training dynamics without changing the broader architecture.