# 🫀 RhythmAttention: Hybrid CNN-Transformer Architecture for Arrhythmia Classification

[Python](https://www.python.org/)
[TensorFlow](https://www.tensorflow.org/)
[License: MIT](https://opensource.org/licenses/MIT)

**RhythmAttention** is a high-performance deep learning framework for the automated classification of electrocardiogram (ECG) signals. Part of the Attention-ECG series, it adapts the hybrid architecture previously used in PulseAttention to the task of arrhythmia classification on the MIT-BIH dataset. By merging the spatial feature extraction of **Residual CNNs** with the long-range temporal modeling of **Transformers (Self-Attention)**, the architecture achieves state-of-the-art robustness in identifying cardiac arrhythmias.

---

## 🚀 Key Innovation: The Hybrid Approach

Standard CNNs excel at capturing morphological features (the P-QRS-T wave shapes), while Transformers excel at modeling the "rhythm" over time. **RhythmAttention** combines both to address the hardest challenges in ECG analysis:

1. **Extreme Class Imbalance:** Effectively managing the roughly 30:1 ratio between Normal beats and minority beats (such as Fusion or Supraventricular beats).
2. **Morphological Mimicry:** Distinguishing Supraventricular (S) beats from Normal (N) beats, which often look nearly identical to the human eye.
3. **Stability in Learning:** Using **Global Gradient Clipping** and a **Stable Focal Loss** to prevent divergence while training on noisy physiological data.

---

## 🏗️ Model Architecture

The model follows a specialized three-stage pipeline:

### 1. Residual CNN Feature Extractor
* **1D ResNet Blocks:** Extract local morphological features while skip connections prevent vanishing gradients.
* **Spatial Dropout:** Drops entire feature maps rather than individual timesteps, which suits time-series data and forces the model to learn more generalized patterns.

### 2. Transformer Contextualizer
* **Positional Embedding:** Since self-attention is permutation-invariant, temporal coordinates are injected so the model can track the order of events within the cardiac cycle.
* **Multi-Head Self-Attention:** Lets the model attend to several parts of the signal simultaneously (e.g., the relationship between the P-wave and the R-peak).

### 3. Classification Head
* **Dual Pooling:** Concatenates Global Average and Global Max Pooling to preserve both the most consistent and the most prominent features before the final Softmax layer.

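As a rough Keras sketch of the three-stage pipeline above — block counts, filter sizes, kernel widths, and head counts here are illustrative assumptions, not the project's exact configuration:

```python
import tensorflow as tf
from tensorflow.keras import layers

class PositionalEmbedding(layers.Layer):
    """Adds a learned positional embedding so self-attention can see order."""
    def __init__(self, length, dim, **kwargs):
        super().__init__(**kwargs)
        self.length = length
        self.pos_emb = layers.Embedding(input_dim=length, output_dim=dim)

    def call(self, x):
        positions = tf.range(start=0, limit=self.length, delta=1)
        return x + self.pos_emb(positions)

def residual_block(x, filters, kernel_size=7):
    # Stage 1 building block: two 1D convolutions plus a skip connection.
    shortcut = x
    y = layers.Conv1D(filters, kernel_size, padding="same", activation="relu")(x)
    y = layers.Conv1D(filters, kernel_size, padding="same")(y)
    if shortcut.shape[-1] != filters:
        shortcut = layers.Conv1D(filters, 1, padding="same")(shortcut)
    y = layers.Activation("relu")(layers.Add()([shortcut, y]))
    return layers.SpatialDropout1D(0.1)(y)  # drops whole feature maps

def build_sketch(input_shape=(187, 1), num_classes=5):
    inputs = layers.Input(shape=input_shape)
    # Stage 1: residual CNN feature extractor.
    x = residual_block(inputs, 64)
    x = layers.MaxPooling1D(2)(x)
    x = residual_block(x, 128)
    # Stage 2: positional embedding + multi-head self-attention.
    x = PositionalEmbedding(x.shape[1], x.shape[-1])(x)
    attn = layers.MultiHeadAttention(num_heads=4, key_dim=32)(x, x)
    x = layers.LayerNormalization()(x + attn)
    # Stage 3: dual pooling (average + max) feeding the softmax head.
    pooled = layers.Concatenate()([
        layers.GlobalAveragePooling1D()(x),
        layers.GlobalMaxPooling1D()(x),
    ])
    outputs = layers.Dense(num_classes, activation="softmax")(pooled)
    return tf.keras.Model(inputs, outputs)
```

The residual skip path and the post-attention layer normalization both address the stability concerns discussed later in the README.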
---

## 🧪 Advanced Training Techniques

* **Stable Focal Loss:** A loss function that penalizes the model more heavily for misclassifying "hard" minority examples, forcing it to learn beyond the majority class.
* **Physiological Augmentation:** Online data augmentation including **Gaussian Noise**, **Amplitude Scaling**, and **Time-Shifting** (±10 samples) to simulate real-world sensor variability.
* **Weighted Random Sampling:** Keeps every training batch balanced, preventing the model from becoming biased toward Normal (N) beats.

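A minimal numpy sketch of the focal-loss idea — the `gamma`/`alpha` values are the common defaults from the focal-loss literature, not necessarily this project's settings, and the clipping is what makes the "stable" variant stable:

```python
import numpy as np

def stable_focal_loss(y_true, y_pred, gamma=2.0, alpha=0.25, eps=1e-7):
    """Per-sample focal loss; y_true is one-hot, y_pred is a softmax output."""
    p = np.clip(y_pred, eps, 1.0 - eps)   # clipping avoids log(0) blow-ups
    p_t = np.sum(y_true * p, axis=-1)     # probability assigned to the true class
    modulator = (1.0 - p_t) ** gamma      # shrinks the loss of easy examples
    return -alpha * modulator * np.log(p_t)

# A confidently-correct beat contributes almost nothing; a hard,
# misclassified minority beat dominates the batch loss.
y_true = np.array([[1, 0], [1, 0]])
y_pred = np.array([[0.95, 0.05], [0.20, 0.80]])
easy, hard = stable_focal_loss(y_true, y_pred)
```

This is exactly the dynamic that keeps the rare S and F beats from being drowned out by Normal beats.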
---

## 📊 Visual Results & Evaluation

### 1. Learning Curves
The curves below show the stability of the model during training. Thanks to **L2 Regularization** and **Global Gradient Clipping**, the training and validation curves converge smoothly, without extreme oscillations.

*Figure 1: Loss and Accuracy curves for the Training vs. Validation sets.*

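Global gradient clipping rescales the entire gradient set when its joint L2 norm exceeds a threshold (in recent TF/Keras versions the optimizers expose this as the `global_clipnorm` argument); a numpy sketch of the idea:

```python
import numpy as np

def clip_by_global_norm(grads, clip_norm):
    # Joint L2 norm across ALL gradient tensors, not per-tensor.
    global_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    scale = min(1.0, clip_norm / (global_norm + 1e-12))
    return [g * scale for g in grads], global_norm

# A gradient list with joint norm 5 gets rescaled down to norm 1;
# gradients already under the threshold pass through untouched.
grads = [np.array([3.0]), np.array([4.0])]
clipped, norm_before = clip_by_global_norm(grads, clip_norm=1.0)
```

Because all tensors share one scale factor, the *direction* of the update is preserved; only its magnitude is capped.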
### 2. Model Confusion Matrix
The confusion matrix below illustrates the model's performance across all five classes. Note the high precision on the 'N' and 'Q' classes, and the improved sensitivity on the 'F' and 'V' categories.

*Figure 2: Confusion Matrix showing identification of the minority classes despite their morphological similarity.*

---

## 📊 Performance Metrics (MIT-BIH Dataset)

Evaluation performed on the **MIT-BIH Arrhythmia Database**:

| Class | Description | Precision | Recall | F1-Score |
|:---:|:---|:---:|:---:|:---:|
| **N** | Normal Beat | 0.97 | 0.96 | **0.96** |
| **S** | Supraventricular Ectopic | 0.42 | 0.35 | **0.38** |
| **V** | Ventricular Ectopic | 0.85 | 0.82 | **0.83** |
| **F** | Fusion Beat | 0.30 | 0.58 | **0.40** |
| **Q** | Unclassified / Unknown | 0.96 | 0.96 | **0.96** |

* **Overall Accuracy:** ~93%
* **Macro Avg F1:** ~0.71 (the unweighted class average, pulled down by the rare S and F classes)

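The macro average is simply the unweighted mean of the per-class F1 scores in the table, which is why the minority classes pull it well below the ~93% overall accuracy:

```python
import numpy as np

# Per-class F1 scores from the table above (order: N, S, V, F, Q).
f1 = np.array([0.96, 0.38, 0.83, 0.40, 0.96])
macro_f1 = f1.mean()  # unweighted: each class counts equally
print(round(macro_f1, 2))  # → 0.71
```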
---

## 🛠️ Getting Started

### Prerequisites
```bash
pip install tensorflow numpy pandas matplotlib seaborn scikit-learn
```

### Usage
1. Place `mitbih_train.csv` and `mitbih_test.csv` in the `Datasets/` directory.
2. Run the main script to train the model.
3. The best model weights are saved automatically as `RhythmAttention_best.keras`.

```python
from model import build_rhythm_attention_model

# Load the architecture
model = build_rhythm_attention_model(input_shape=(187, 1), num_classes=5)
# Load the trained weights
model.load_weights('models/RhythmAttention_best.keras')
```

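As a hedged sketch of the preprocessing the `(187, 1)` input shape implies — assuming each CSV row stores 187 signal samples followed by the class label in its last column, with a synthetic array standing in for the actual file:

```python
import numpy as np

# Synthetic stand-in for np.loadtxt("Datasets/mitbih_train.csv", delimiter=",").
rng = np.random.default_rng(0)
raw = rng.random((8, 188))
raw[:, -1] = rng.integers(0, 5, size=8)  # class labels 0..4 in the last column

X = raw[:, :-1].reshape(-1, 187, 1).astype("float32")  # (batch, 187, 1) for the model
y = raw[:, -1].astype(int)                             # integer class labels
print(X.shape)  # → (8, 187, 1)
```

The trailing channel dimension is what the model's `Conv1D` front end expects.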
## 🏁 Conclusion

The **RhythmAttention** project demonstrates that a **Hybrid CNN-Transformer architecture** can be significantly more effective for ECG classification than traditional CNN-only or RNN-only models.

Key takeaways from this implementation:
* **Synergy of Features:** While the CNN layers extract local morphology (the P-QRS-T complexes), the Transformer layers provide a "global view" of the heart's rhythm, which is crucial for identifying arrhythmias.
* **Robustness to Imbalance:** Through **Focal Loss** and **Weighted Sampling**, the model learns to identify rare but life-critical beats (such as Fusion beats) without being overwhelmed by the majority "Normal" class.
* **Stability:** **Spatial Dropout**, **Residual Connections**, and **Global Gradient Clipping** keep training stable and avoid the exploding gradients often seen in deep Attention-based architectures.

Overall, RhythmAttention provides a reliable, scalable, and high-precision foundation for the next generation of automated cardiac diagnostic tools.

---

## 📝 Author
**Sayyed Hossein Hosseini**
*Deep Learning Researcher & Healthcare AI Enthusiast*