Proving that high-quality text-to-image models can be trained on small, vertical datasets using consumer hardware.
This project demonstrates that you don't need massive datasets (millions of images) or datacenter infrastructure to train a functional text-to-image diffusion model. By focusing on a curated vertical dataset and leveraging pre-computed embeddings, we achieve practical training on a single consumer GPU.
- Vertical Dataset Viability: Train a text-to-image model on a small, curated dataset (thousands, not millions)
- Consumer Hardware: Run entirely on consumer-grade GPUs (RTX 4090, etc.)
- Pre-computed Embeddings: Avoid expensive forward passes during training by pre-computing all conditioning signals
- Flow Matching: Use modern rectified flow formulation for stable, efficient training
- Quad Conditioning: Combine text (T5), visual (DINOv3), and pose (DWPose) signals for precise control
- Production: 768 hidden, 18 layers, 12 heads (~400M parameters)
- Patch-based: 2×2 patches in pixel space for high detail
- Pixel-space training: Model predicts x0 (RGB pixels) directly — no VAE encoder/decoder
- Quad conditioning:
- Text via T5-Large embeddings (1024-dim, 512 tokens)
- Visual style via DINOv3 CLS token (1024-dim)
- Spatial layout via DINOv3 patch embeddings (~4000 tokens × 1024-dim)
- Pose via DWPose whole-body keypoints (133 joints × 3-dim [x, y, confidence])
- Dynamic positional encoding: Supports variable aspect ratios
- Algorithm: Flow matching (continuous-time diffusion)
- Prediction type: x_prediction (model outputs x0 directly)
- Loss: MSE between predicted and target pixels
- Timestep sampling: Logit-normal distribution (focuses on mid-diffusion)
- Timestep Encoding: High-frequency 1000x scaled sinusoidal embeddings for stable flow
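To make the recipe concrete, here is a NumPy sketch of one x-prediction flow-matching step with logit-normal timestep sampling. The function names (`sample_timesteps`, `flow_matching_loss`) are illustrative, not the project's actual API:

```python
import numpy as np

def sample_timesteps(batch, rng, mu=0.0, sigma=1.0):
    """Logit-normal sampling: sigmoid of a Gaussian, which concentrates
    timesteps around mid-diffusion where the task is hardest."""
    return 1.0 / (1.0 + np.exp(-rng.normal(mu, sigma, size=batch)))

def flow_matching_loss(model, x0, rng):
    """One x-prediction training step: interpolate along the straight
    (rectified-flow) path x_t = (1 - t) * x0 + t * eps, then regress
    the clean pixels x0 directly with an MSE loss."""
    t = sample_timesteps(x0.shape[0], rng)         # (B,)
    eps = rng.standard_normal(x0.shape)            # Gaussian noise endpoint
    t_b = t.reshape(-1, *([1] * (x0.ndim - 1)))    # broadcast over (C, H, W)
    x_t = (1.0 - t_b) * x0 + t_b * eps             # noised input
    pred_x0 = model(x_t, t)                        # model outputs x0, not velocity
    return float(np.mean((pred_x0 - x0) ** 2))     # pixel-space MSE
```

A perfect model (one that always returns the clean image) drives this loss to zero, which is a useful sanity check when wiring up the training loop.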
- CFG strategy:
- Training dropout (7 mutually exclusive categories): 10% uncond, 25% text-only, 5% DINO-CLS-only, 5% DINO-patches-only, 10% drop-pose, 5% pose-only, 40% all present
- Self-guidance (default with TREAD): 2 passes — dense (all tokens) vs routed (50% tokens). Single guidance_scale (default 3.0)
- Dual CFG (fallback): 3 passes — unconditional + text-only + DINO-only
- 133 whole-body keypoints: 17 body + 6 feet + 68 face + 42 hands (COCO-WholeBody format)
- Projection: MLP (3→768→768 with GELU) + learned per-joint type embeddings
- Confidence masking: Joints below threshold (0.05) replaced with learned [NULL_POSE] token
- CFG dropout: Dedicated pose dropout categories (independent of text/DINO dropout)
- Cross-attention: Pose tokens appended to DINO patches in conditioning sequence
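A minimal NumPy sketch of the pose-token pipeline described above (MLP projection, per-joint type embeddings, confidence masking with a learned [NULL_POSE] token). The weights here are random stand-ins for learned parameters, and the names are illustrative:

```python
import numpy as np

HIDDEN, N_JOINTS, CONF_THRESH = 768, 133, 0.05

rng = np.random.default_rng(0)
# Hypothetical, randomly initialized stand-ins for learned parameters.
W1 = rng.standard_normal((3, HIDDEN)) * 0.02
W2 = rng.standard_normal((HIDDEN, HIDDEN)) * 0.02
joint_type_emb = rng.standard_normal((N_JOINTS, HIDDEN)) * 0.02
null_pose_token = rng.standard_normal(HIDDEN) * 0.02

def gelu(x):
    """Tanh approximation of GELU."""
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

def embed_pose(keypoints):
    """keypoints: (133, 3) [x_norm, y_norm, confidence] -> (133, 768) tokens.
    Joints below the confidence threshold are replaced wholesale by the
    learned [NULL_POSE] token so unreliable detections carry no signal."""
    h = gelu(keypoints @ W1) @ W2 + joint_type_emb  # MLP + per-joint embedding
    low_conf = keypoints[:, 2] < CONF_THRESH
    h[low_conf] = null_pose_token                   # confidence masking
    return h
```

The resulting 133 tokens would then be appended to the DINO patch sequence for cross-attention, per the conditioning layout above.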
A curated collection of images from a vertical domain (specific subject matter, consistent quality).
All preprocessing is done via scripts/generate_approved_image_dataset.py:
Images are assigned to 7 buckets (~1 megapixel each) to enable variable aspect ratio training without distortion:
1024×1024 (1.00) - Square
1216×832 (1.46) - Landscape
832×1216 (0.68) - Portrait
1280×768 (1.67) - Wide landscape
768×1280 (0.60) - Tall portrait
1344×704 (1.91) - Very wide
704×1344 (0.52) - Very tall
Bucketing algorithm:
- Assign image to closest aspect ratio bucket
- Resize-to-cover (scale to fill bucket, no black bars)
- Center-crop to exact bucket dimensions
- All subsequent processing uses this bucketed view
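The bucketing steps above can be sketched in a few lines of plain Python; `assign_bucket` and `cover_size` are illustrative helper names, not the preprocessing script's API:

```python
# The 7 aspect-ratio buckets (~1 MP each) from the table above.
BUCKETS = [(1024, 1024), (1216, 832), (832, 1216), (1280, 768),
           (768, 1280), (1344, 704), (704, 1344)]

def assign_bucket(w, h):
    """Pick the bucket whose aspect ratio is closest to the image's."""
    ar = w / h
    return min(BUCKETS, key=lambda b: abs(b[0] / b[1] - ar))

def cover_size(w, h, bucket):
    """Resize-to-cover: the smallest scale that fills the bucket in both
    dimensions (no black bars); a center-crop to the exact bucket size follows."""
    bw, bh = bucket
    scale = max(bw / w, bh / h)
    return round(w * scale), round(h * scale)
```

For example, a 3000x2000 image (aspect 1.50) lands in the 1216x832 bucket (1.46 is the nearest ratio), is scaled to cover 1216x832, and is then center-cropped.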
Dense, objective image descriptions generated using Google Gemma3:27b:
- Physical attributes, composition, lighting, colors
- No subjective interpretation
- ~150-200 words per image
- Example: "A fair-skinned woman with a slender build and visible clavicle is positioned in a close-up, frontal portrait. Her dark, curly hair frames her face..."
All embeddings are pre-computed and stored to avoid expensive forward passes during training:
- Model: Google T5-Large
- Output: 512 tokens × 1024 dimensions (fp16)
- Storage: data/derived/t5/*.npy (~1 MB per image)
- Caption → T5 encoder → fixed-length sequence embeddings
- Model: Facebook DINOv3-ViT-L/16 (304M parameters)
- CLS token: Single 1024-dim global feature vector
- Captures color palette, style, composition
- Storage: data/derived/dinov3/*.npy (~4 KB per image)
- Patch embeddings (variable spatial patches):
- ~4000 × 1024 dimensions for spatial conditioning
- Storage: data/derived/dinov3_patches/*.npy (~0.78 MB per image)
- Format: RGB float16, range [0, 1]
- Output: 3 channels × H × W (bucket-specific dimensions)
- Storage: data/derived/image/*.npy (varies by bucket)
- Model: DWPose (ONNX) — YOLOx-L detector + DW-LL pose estimator
- Output: 133 joints × 3 dimensions [x_norm, y_norm, confidence] (float16)
- Normalization: x_norm = (2x/W) - 1, y_norm = (2y/H) - 1 (relative to bucket)
- Storage: data/derived/pose/*.npy (~1.6 KB per image)
Embeddings are packed into tar shards for efficient streaming during training:
data/shards/10000/
bucket_1024x1024/
shard-000000.tar # Contains .npy files for each modality
shard-000001.tar
bucket_1216x832/
...
Each shard contains:
- {image_id}.image.npy - RGB pixel data (3, H, W)
- {image_id}.dinov3.npy - DINOv3 CLS token
- {image_id}.dinov3_patches.npy - DINOv3 spatial patches
- {image_id}.t5h.npy - T5 text embeddings
- {image_id}.t5m.npy - T5 attention mask
- {image_id}.pose.npy - DWPose keypoints (133, 3)
- {image_id}.json - Metadata (bucket, dimensions, caption)
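The training pipeline streams these shards with the webdataset library; as a self-contained illustration of the on-disk layout, this stdlib-only sketch groups a shard's .npy members by image id (`read_shard` is an illustrative name, not the project's loader):

```python
import io
import tarfile
import numpy as np

def read_shard(fileobj):
    """Group a WebDataset-style tar shard's .npy members by image id.
    Returns {image_id: {modality: array}}. Assumes image ids contain no
    dots; the {image_id}.json metadata is skipped in this sketch."""
    samples = {}
    with tarfile.open(fileobj=fileobj) as tar:
        for member in tar:
            if not member.name.endswith(".npy"):
                continue  # ignore the JSON metadata entry
            image_id, modality, _ = member.name.split(".", 2)
            arr = np.load(io.BytesIO(tar.extractfile(member).read()))
            samples.setdefault(image_id, {})[modality] = arr
    return samples
```

Grouping by the filename prefix is exactly how webdataset assembles a sample, which is why every modality for an image must share one {image_id} stem within the same shard.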
Per image:
- Pixel data: ~6 MB (depends on bucket)
- T5 embeddings: ~1 MB
- DINOv3 CLS: ~4 KB
- DINOv3 patches: ~0.78 MB
- DWPose keypoints: ~1.6 KB
- Total: ~8 MB per image
Full dataset:
- Embeddings + original images
- Storage requirements scale linearly with dataset size
Baseline (384 hidden, 12 layers):
- GPU: RTX 4090 (24 GB VRAM) or similar
- RAM: 32 GB system memory
- Storage: Sufficient SSD for embeddings
- Training speed: ~1.6 it/s with gradient accumulation
Production (768 hidden, 18 layers):
- GPU: RTX 4090 or similar (24 GB VRAM)
- RAM: 64 GB system memory recommended
- Storage: Sufficient SSD for embeddings
- Training speed: ~1.0 it/s with gradient accumulation + gradient checkpointing
model:
hidden_size: 768
depth: 18
num_heads: 12
patch_size: 2
in_channels: 3 # RGB pixels
prediction_type: x_prediction # Pixel-space
training:
total_steps: 6000 # Optimizer steps (accumulated)
batch_size: 1
grad_accumulation_steps: 256 # Effective batch = 256
learning_rate: 3e-4 → 1e-6 (cosine decay)
warmup_steps: 500
mixed_precision: bfloat16
gradient_checkpointing: true # Saves ~3-4× memory
optimizer:
type: Muon # Hybrid Muon + AdamW (or "AdamW" for pure AdamW)
muon:
momentum: 0.95
nesterov: true
ns_steps: 5
adjust_lr_fn: match_rms_adamw
cfg_dropout:
p_uncond: 0.10
p_text_only: 0.25
p_dino_cls_only: 0.05
p_dino_patches_only: 0.05
p_drop_pose: 0.10
p_pose_only: 0.05
resolution_schedule: # Train at lower res first (optional)
- until_step: 3000
scale: 0.5 # Half resolution (4× fewer tokens)
- until_step: 6000
scale: 1.0 # Full resolution
repa:
enabled: true
weight: 0.5 # REPA loss weight
block_index: -1 # -1 = depth // 2 (block 9)
loss_type: cosine # Cosine similarity alignment
tread:
enabled: true # Token routing for throughput
routing_probability: 0.5
self_guidance: true # Use self-guidance instead of dual CFG
guidance_scale: 3.0

- Gradient Accumulation: Effective batch size of 256 with batch_size=1
- FlashAttention (SDPA): Uses memory-efficient scaled_dot_product_attention to avoid OOM on large cross-attention sequences
- Mixed Precision: bfloat16 training (required for flow matching stability)
- Gradient Checkpointing: Trade 20-30% speed for 3-4× memory savings
- Bucket-aware Batching: Sample from aspect ratio buckets proportionally
- EMA: Exponential moving average of weights (decay=0.9999) with 500-step warmup
- REPA (REPresentation Alignment): Auxiliary loss aligning transformer hidden states with DINOv3 patch features at the middle block, improving convergence and representation quality (weight=0.5, cosine similarity)
- TREAD (Token Routing): Randomly routes 50% of latent tokens past middle blocks (1→depth-2), effectively halving compute for 16 of 18 blocks. Parameter-free — adds zero new weights. Pairs with self-guidance sampling (2 passes instead of 3-pass dual CFG)
- Muon Optimizer: Hybrid Muon + AdamW — all 2D weight matrices (~237M params) use Muon's Newton-Schulz orthogonalization for geometry-aware updates; non-2D params (convs, biases, ~0.1M) stay on AdamW. Uses adjust_lr_fn="match_rms_adamw" so both optimizers share the same LR schedule.
- Resolution Scheduling: Train at lower resolution first (e.g., 0.5× spatial scale = 4× fewer tokens), then transition to full resolution.
- DWPose Conditioning: Whole-body pose keypoints with dedicated CFG dropout, confidence masking, and learned [NULL_POSE] embeddings.
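The core of TREAD is just index bookkeeping. A simplified NumPy sketch, assuming uniform random selection without replacement (the routing decision) and scatter-back after the middle blocks (`tread_route`/`tread_merge` are illustrative names):

```python
import numpy as np

def tread_route(tokens, keep_prob, rng):
    """Select a random subset of tokens to run through the middle blocks;
    the remainder bypass those blocks entirely. Returns the routed subset
    and the indices needed to scatter results back."""
    n = tokens.shape[0]
    keep = max(1, int(round(n * keep_prob)))
    idx = rng.permutation(n)[:keep]
    return tokens[idx], idx

def tread_merge(tokens, routed, idx):
    """Write the processed subset back into its original positions,
    leaving bypassed tokens untouched."""
    out = tokens.copy()
    out[idx] = routed
    return out
```

With keep_prob=0.5, the middle blocks see half the sequence, which is where the roughly 2x throughput gain comes from; no parameters are added, so the trick is free at the weight level.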
- Horizontal flip: 50% probability (applied to pixel images)
- Note: T5 embeddings are NOT flipped (known limitation - would require caption rewriting)
Comprehensive validation suite runs periodically during training, but can also be executed on demand for specific checkpoints.
If training is interrupted during a validation phase, or if you want to back-test an older checkpoint with new validation logic, you can run the validation suite standalone:
python scripts/run_checkpoint_validation.py \
--config experiments/2026-02-22_1227/config.yaml \
--checkpoint experiments/2026-02-22_1227/checkpoints/checkpoint_step001200.pt

This loads the model and EMA weights, runs the full validation suite (Reconstruction, DINO Swap, CFG Divergence, Text Manipulation), and saves the LPIPS scores and images to the corresponding validation_outputs/step{N}/ directory within the experiment folder.
- Goal: Test model's ability to reconstruct training images
- Method: Use original image's DINO + T5 embeddings as "perfect" conditioning
- Metric: LPIPS (perceptual similarity, lower is better)
- Expected: <0.3 excellent, 0.3-0.5 good, 0.5-0.7 blurry, >0.7 poor
- Goal: Test visual style transfer
- Method: Swap DINO embeddings between image pairs, keep original text
- Expected: Generated image should match swapped DINO's style but original text's content
- Goal: Test text conditioning strength
- Method: Modify text embeddings (e.g., "brunette" → "blonde"), keep same DINO
- Expected: Changes in generation should match text modifications
- Goal: Detect overfitting or mode collapse
- Method: Track distribution of LPIPS scores across validation set
- Expected: Stable distribution, no sudden spikes
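As a small illustration of how the LPIPS bands above can summarize a validation run (illustrative helpers, not the project's validation code):

```python
from collections import Counter

def lpips_bucket(score):
    """Map a reconstruction LPIPS score onto the qualitative bands:
    <0.3 excellent, 0.3-0.5 good, 0.5-0.7 blurry, >=0.7 poor."""
    if score < 0.3:
        return "excellent"
    if score < 0.5:
        return "good"
    if score < 0.7:
        return "blurry"
    return "poor"

def summarize_lpips(scores):
    """Mean score plus a band histogram; a drifting histogram across
    checkpoints is the kind of spike the distribution check watches for."""
    return {
        "mean": sum(scores) / len(scores),
        "buckets": dict(Counter(lpips_bucket(s) for s in scores)),
    }
```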
- Size: 25 images per test
- Sampling: Deterministic (seeded at 42) for reproducibility across runs
- Source: Separate holdout set from training data
- Outputs: Saved to validation_outputs/step{N}/
- Training:
- Loss (flow matching MSE + REPA alignment)
- REPA loss (cosine alignment with DINOv3 patches)
- LPIPS loss (perceptual quality on decoded crops, when enabled)
- Gradient norm (overall + per-layer)
- Velocity norm (RMS of predicted velocity vectors, should be ~1.0)
- Learning rate
- System:
- sys/iter_per_sec — wall-clock throughput
- sys/active_tokens — tokens processed per forward pass (quantifies TREAD savings)
- memory/peak_vram_gb — peak GPU memory (max_memory_allocated)
- memory/vram_allocated_gb, memory/vram_reserved_gb
- Validation:
- LPIPS (reconstruction, DINO swap, text manipulation)
- Per-test statistics
Quick 4-image generation every 100 steps:
- Fixed deterministic samples for consistency
- Saved to experiments/{timestamp}/visual_debug/step{N}/
- Useful for spotting training issues early
- Saved every 500 steps to experiments/{timestamp}/checkpoints/
- Includes:
- Model weights
- Optimizer state
- EMA weights
- RNG states (Python, NumPy, PyTorch CPU/CUDA)
- Training step and epoch
- Checkpoint size: ~3.6 GB (768 hidden model)
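The checkpoint payload above can be pictured as a plain dictionary; this sketch uses stdlib RNG capture for brevity (the real checkpoints would serialize tensors and PyTorch CPU/CUDA RNG states via torch.save, and `build_checkpoint` is a hypothetical name):

```python
import random
import numpy as np

def build_checkpoint(model_state, optimizer_state, ema_state, step, epoch):
    """Assemble a resumable checkpoint payload. Capturing RNG states is
    what makes a resumed run continue bit-for-bit where it left off."""
    return {
        "model": model_state,          # model weights
        "optimizer": optimizer_state,  # optimizer moments/buffers
        "ema": ema_state,              # EMA shadow weights
        "step": step,
        "epoch": epoch,
        "rng": {
            "python": random.getstate(),
            "numpy": np.random.get_state(),
            # torch CPU/CUDA states would be captured analogously
        },
    }
```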
Production Training (768 hidden, 18 layers)
- ✅ Pixel-Space Training: Direct RGB prediction (x_prediction) — no VAE encoder/decoder needed
- ✅ Quad Conditioning: Text (T5-Large) + DINOv3 CLS + DINOv3 patches + DWPose keypoints
- ✅ DWPose Integration: 133 whole-body keypoints with confidence masking, dedicated CFG dropout, learned [NULL_POSE] embeddings
- ✅ Memory Optimized: FlashAttention (SDPA) and Gradient Checkpointing enable training at 1024px on 24GB GPUs
- ✅ REPA Alignment: Auxiliary loss aligns hidden states with DINOv3 teacher features for faster convergence
- ✅ TREAD Token Routing: Routes 50% of tokens past middle blocks for ~2× throughput, with self-guidance sampling
- ✅ Muon Optimizer: Hybrid Muon + AdamW for geometry-aware weight updates
- ✅ Resolution Scheduling: Multi-phase training at lower resolution first for faster convergence
- 🚧 In progress: Training on curated dataset, scaling to 70k images
- Caption Flipping: Horizontal flip augmentation doesn't modify T5 embeddings
- Left/right mentions in captions don't swap with image
- Would require NLP caption rewriting (future work)
- Training Data Size: Preparing a 70k image dataset.
prx-tg/
├── production/ # Core training code
│ ├── train.py # Training loop (optimizer-step based)
│ ├── model.py # NanoDiT with SDPA & high-freq timesteps
│ ├── data.py # WebDataset loading, bucket-specific stats
│ ├── validate.py # Validation suite
│ ├── visual_debug.py # Quick sample generation
│ ├── sample.py # Inference sampler (Euler)
│ └── config.yaml # Training configuration (14-day schedule)
│
├── scripts/
│ ├── generate_approved_image_dataset.py # Data preprocessing
│ └── test_*.py # Diagnostic scripts
│
├── data/
│ ├── approved/ # Original images
│ ├── derived/ # Pre-computed embeddings
│ └── shards/15000/ # WebDataset tar shards
│
├── experiments/ # Unified experiment tracking (checkpoints, logs, visuals)
└── docs/ # Documentation and plans
# Generate all embeddings
python -m scripts.generate_approved_image_dataset \
--device cuda \
--pass-filter all \
--verbose
# This will:
# - Caption images with Gemma3:27b
# - Extract T5-Large text embeddings
# - Extract DINOv3 visual embeddings (CLS + patches)
# - Extract DWPose keypoints (133 whole-body joints)
# - Save bucketed RGB pixel images
# - Create WebDataset shards

# Start fresh training
python -m production.train_production
# Resume from checkpoint
python -m production.train_production \
--resume checkpoints/checkpoint_step010000.pt

# Generate images (implementation in progress)
python -m production.sample \
--checkpoint checkpoints/checkpoint_latest.pt \
--prompt "Your text prompt here" \
--reference-image path/to/style.jpg \
--output output.png

The ablation grid isolates the contribution of each key technique. Each run trains for 5,000 steps on the same 7k FFHQ subset.
| Run | TREAD | Optimizer | REPA | Purpose |
|---|---|---|---|---|
| A | ✗ | AdamW | ✗ | Baseline (no optimizations) |
| B | ✓ | AdamW | ✗ | Proves TREAD's VRAM/speed value |
| C | ✓ | Muon | ✗ | Proves Muon's convergence advantage |
| D | ✓ | Muon | ✓ | Full stack — proves REPA's quality boost |
# Run the full grid (A → B → C → D)
bash scripts/run_ablations.sh
# Run a subset
bash scripts/run_ablations.sh B D
# Select GPU
GPU=1 bash scripts/run_ablations.sh

Configs are in experiments/ablations/config_{A,B,C,D}_*.yaml. Results are logged to TensorBoard under each experiment's tensorboard/ directory.
- A vs B: memory/peak_vram_gb and sys/iter_per_sec — quantifies TREAD's compute savings
- B vs C: train/loss slope during first 5,000 steps — proves Muon's faster convergence
- C vs D: validation/reconstruction_lpips — proves REPA's perceptual quality improvement
- Python 3.13+
- PyTorch 2.6+ with CUDA
- transformers (HuggingFace)
- webdataset
- lpips
- Pillow
- numpy
See full environment in .venv/ (not committed).
See docs/prx-part3-analysis.md for a detailed analysis of techniques from Photoroom's PRX Part 3.
- Larger Dataset: Expand dataset size towards the 70k target.
- Latent REPA: Replace DINOv3 teacher alignment with self-supervised latent alignment (future exploration).
(To be determined)
- T5-Large: Google (text encoder)
- DINOv3: Facebook AI Research (visual encoder)
- DWPose: Open-source whole-body pose estimation (ONNX)
- Gemma3:27b: Google (caption generation)
- Flow Matching: Inspired by Stable Diffusion 3 and Flux.1 training methodology
This is a research/hobby project demonstrating that high-quality generative models can be trained on consumer hardware with curated vertical datasets. Not intended for commercial use.