
🚀 Python Reasoning-Enhanced Code Generation with GRPO-LoRA

Fine-tuning Llama-3-8B with Group Relative Policy Optimization (GRPO)

🌟 Innovation & Summary

VLLM_EngineCore is a post-training framework designed to close the "reasoning gap" in Large Language Models: standard models can write code, but they often fail at complex, multi-step logic.

Our Innovation: We bypass the need for a separate Reward Model (common in PPO) by using GRPO. This method generates a group of outputs per prompt and calculates a "relative reward" by comparing them against each other. This forces the model to refine its internal thought process (Reasoning Trace) to maximize the probability of passing functional unit tests.
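A minimal sketch of the group-relative advantage computation described above (plain Python; the reward values are hypothetical stand-ins for unit-test outcomes, not real training data):

def group_relative_advantages(rewards):
    # rewards: one scalar per completion in the group (e.g. 1.0 if all unit tests pass)
    mean = sum(rewards) / len(rewards)
    std = (sum((r - mean) ** 2 for r in rewards) / len(rewards)) ** 0.5
    return [(r - mean) / (std + 1e-8) for r in rewards]

# Example: 8 completions sampled for one prompt
rewards = [1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0]
print(group_relative_advantages(rewards))  # positive for passing completions, negative otherwise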

📊 Data Information

We utilize the MBPP (Mostly Basic Python Problems) dataset, which consists of ~1,000 crowd-sourced Python programming problems.

Example Pair:

  • Prompt (Natural Language): Write a function to find the sum of all odd numbers in a list.
  • Target (Python Code):
def sum_odd(numbers):
    return sum([n for n in numbers if n % 2 != 0])

# Associated Unit Tests
assert sum_odd([1, 2, 3]) == 4
assert sum_odd([2, 4, 6]) == 0
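The repository keeps MBPP locally as .jsonl files (see the directory structure below), but the same problems can be inspected directly from the Hugging Face Hub; the field names below follow the Hub copy of MBPP and are an assumption about how the data is accessed, not the project's loader:

from datasets import load_dataset

mbpp = load_dataset("mbpp", split="train")
example = mbpp[0]
print(example["text"])       # natural-language prompt
print(example["code"])       # reference solution
print(example["test_list"])  # assert statements used as unit tests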

Llama Model Architecture:


The project utilizes the Llama-3-8B-Instruct backbone, enhanced via LoRA (Low-Rank Adaptation) and the GRPO RL loop.

Key Architectural Details:

  • 32 Transformer Layers: Standard decoder-only architecture.
  • Grouped-Query Attention (GQA): 32 query heads and 8 key-value heads for efficient inference during RL generation.
  • Rotary Positional Embeddings (RoPE): Base frequency of 500k to support long reasoning traces.
  • LoRA Integration: Adapters are applied to the q_proj, v_proj, and MLP projection blocks, keeping the trainable adapter size under 200 MB (a configuration sketch follows this list).
  • GRPO Mechanism: Generates a group of G completions per prompt at each step; each completion's advantage is computed relative to its group as A_i = (r_i − mean(r_1..r_G)) / std(r_1..r_G).
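A configuration sketch for the LoRA setup above, assuming the peft library; the rank, alpha, and dropout values are illustrative placeholders rather than the project's exact settings:

from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)
# model = get_peft_model(base_model, lora_config)  # base_model: Llama-3-8B-Instruct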

⚙️ Training Parameters

Training was conducted on a single GPU (24 GB+ VRAM recommended), using 4-bit quantization to fit the 8B model.

Parameter         Value              Details
Epochs            3 (iterative RL)   Each epoch visits every prompt once while exploring multiple completion paths
Max Steps         5,760              Total gradient update steps
Learning Rate     2e-6               Ultra-low LR for stable policy updates
Batch Size        1                  Per device (effective batch controlled by gradient accumulation)
Group Size (G)    8                  Number of samples generated per prompt
Optimizer         Paged AdamW 8-bit  Memory-efficient optimization
Loss Function     GRPO proxy loss    Clipped surrogate objective
KL Penalty (β)    0.01               Keeps the model from drifting too far from base Llama-3
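For reference, a minimal PyTorch sketch of the clipped surrogate objective with the KL penalty from the table; the clip range (0.2), the k3 KL estimator, and the tensor shapes are assumptions, not the project's exact implementation:

import torch

def grpo_loss(logprobs, old_logprobs, ref_logprobs, advantages, mask,
              clip_eps=0.2, beta=0.01):
    # logprobs / old_logprobs / ref_logprobs: (batch, seq_len) per-token log-probs of the
    # sampled completions under the current, behaviour, and frozen reference policies.
    # advantages: (batch,) group-relative advantages, broadcast over completion tokens.
    # mask: (batch, seq_len), 1 for completion tokens, 0 for padding.
    ratio = torch.exp(logprobs - old_logprobs)
    adv = advantages.unsqueeze(-1)
    unclipped = ratio * adv
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * adv
    surrogate = torch.min(unclipped, clipped)
    # k3 estimator of KL(pi_theta || pi_ref), penalising drift from the base model
    kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1.0
    per_token_loss = -(surrogate - beta * kl)
    return (per_token_loss * mask).sum() / mask.sum()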

📈 Results & Evaluation

The model was evaluated across multiple checkpoints to track the impact of RL on coding accuracy. Performance was measured using the Pass@1 metric on the MBPP test set.
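Pass@1 here means a problem counts as solved only if the single sampled completion passes every associated unit test. A minimal sketch of that bookkeeping (generate and run_tests are hypothetical stand-ins for the project's generation and test-execution helpers):

def pass_at_1(problems, generate, run_tests):
    # problems: list of (prompt, test_list) pairs from the MBPP test split
    # generate: fn(prompt) -> one code completion
    # run_tests: fn(code, test_list) -> True if every assert passes
    passed = sum(run_tests(generate(prompt), tests) for prompt, tests in problems)
    return passed / len(problems)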

1. Quantitative Performance

This table tracks the performance of the Llama-3-8B-GRPO model across all checkpoints, comparing 0-Shot vs. 3-Shot Pass@1 accuracy; the improvement column is measured relative to the base model's 3-Shot score.

Checkpoint   0-Shot Pass@1   3-Shot Pass@1   Improvement (%)
Base         0.550           0.560           —
500          0.556           0.558           -0.36%
1000         0.550           0.572           +2.14%
1500         0.546           0.572           +2.14%
2000         0.546           0.566           +1.07%
2500         0.552           0.568           +1.43%
3000         0.554           0.580           +3.57%
3500         0.548           0.572           +2.14%
4000         0.552           0.576           +2.86%
4500         0.544           0.574           +2.50%
5000         0.544           0.574           +2.50%
5500         0.542           0.566           +1.07%
5760         0.552           0.568           +1.43%

2. Visualization

(Plot: Pass@1 accuracy across training checkpoints.)

3. Key Observations

  • Format Compliance: Adherence to the required <thought> and <answer> XML tags rose from 0% at the start of training as the format reward took hold.
  • Peak Performance: The best reasoning-accuracy balance was reached at step 3000 (+3.57% 3-Shot Pass@1 over base), after which slight overfitting / reasoning drift was observed.

📂 Directory Structure

VLLM_EngineCore/
├── data/mbpp/            # Localized .jsonl dataset files
├── eval_results/         # Detailed JSON performance logs
├── outputs/              # Final LoRA adapters and checkpoints
├── results/              # CSV data and plot images
├── src/                  
│   ├── train.py          # Main GRPO training script
│   ├── rewards.py        # Logic for XML format & unit test rewards
│   └── plot.py           # Visualization of reward curves
├── requirements.txt      # Dependency list
└── README.md             # Project documentation
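src/rewards.py combines an XML-format reward with a unit-test reward. The sketch below illustrates one way such rewards can be written; the function names, signatures, and exec-based test runner are assumptions for illustration, not the file's actual contents:

import re

def format_reward(completion):
    # Full credit only for a well-formed <thought>...</thought><answer>...</answer> response.
    pattern = r"<thought>.*?</thought>\s*<answer>.*?</answer>"
    return 1.0 if re.search(pattern, completion, re.DOTALL) else 0.0

def unit_test_reward(completion, test_list):
    # Pull the code out of the <answer> block and run the MBPP asserts against it.
    match = re.search(r"<answer>(.*?)</answer>", completion, re.DOTALL)
    if not match:
        return 0.0
    namespace = {}
    try:
        exec(match.group(1), namespace)   # define the candidate function
        for test in test_list:
            exec(test, namespace)         # each test is an assert statement
        return 1.0
    except Exception:
        return 0.0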

🛠️ Getting Started

Follow these steps to replicate the results:

  1. Clone the Repo: git clone https://github.com/DURGESH716/LLaMA-8B-GRPO-Reinforcement-Learning.git
  2. Install Dependencies: pip install -r requirements.txt
  3. Download Dataset: python scripts/download_mbpp.py
  4. Execute Training: python src/train.py --model_id meta-llama/Llama-3-8B-Instruct --dataset mbpp
  5. Run Evaluation: python src/evaluate.py --checkpoint outputs/checkpoint-3000

Challenges & Solutions

  • Reward Hacking: The model initially learned to output extremely long text to "look" like reasoning. Solution: Introduced a length-penalty factor in the reward function (sketched after this list).
  • VRAM OOM: Generating 8 completions for an 8B model exceeded 24GB VRAM. Solution: Integrated Unsloth 4-bit quantization and gradient checkpointing.
  • Formatting Failures: The model struggled to wrap code in the required <answer> tags. Solution: Added a "soft format" reward that gives partial credit for partial tags.
  • KL Divergence Spikes: Sudden training collapse due to aggressive policy updates. Solution: Decreased the learning rate and increased the warmup ratio to 10%.
  • Data Contamination: Some MBPP tasks were too simple, leading to "overfitting." Solution: Filtered the training set to prioritize problems with multiple logic branches.
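The length penalty and soft-format fixes above can be expressed as simple reward-shaping terms; the thresholds and weights below are illustrative assumptions:

def length_penalty(completion, max_chars=2000, penalty=0.5):
    # Discourage padding the response with filler text that only "looks" like reasoning.
    return -penalty if len(completion) > max_chars else 0.0

def soft_format_reward(completion):
    # Partial credit per required tag instead of all-or-nothing formatting.
    tags = ("<thought>", "</thought>", "<answer>", "</answer>")
    return 0.25 * sum(tag in completion for tag in tags)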

Future Scope

  • Multi-Step Verifiers: Integrating a live Python interpreter to give "runtime error" rewards during training.
  • Multi-Model Distillation: Using the 70B Llama-3 as a "Teacher" reward model for the 8B "Student."
  • Cross-Domain Reasoning: Expanding beyond code to mathematical proof generation (GSM8K).
  • Direct Preference Optimization (DPO): Stacking DPO on top of GRPO for final human-alignment polishing.
  • Hardware Scaling: Migrating to multi-GPU training to support larger group sizes (G).
