🫀 RhythmAttention: Hybrid CNN-Transformer Architecture for Arrhythmia Classification

Python · TensorFlow · Deep Learning · License: MIT

RhythmAttention is a high-performance Deep Learning framework designed for the automated classification of Electrocardiogram (ECG) signals. By merging the spatial feature extraction capabilities of Residual CNNs with the long-range temporal modeling of Transformers (Self-Attention), this architecture achieves state-of-the-art robustness in identifying cardiac arrhythmias.


🚀 Key Innovation: The Hybrid Approach

Standard CNNs excel at capturing morphological features (P-QRS-T wave shapes), while Transformers excel at understanding the "rhythm" over time. RhythmAttention combines both to solve the most difficult challenges in ECG analysis:

  1. Extreme Class Imbalance: Effectively managing the roughly 30:1 ratio between Normal beats and minority classes such as Fusion (F) and Supraventricular (S) beats.
  2. Morphological Mimicry: Distinguishing between Supraventricular (S) and Normal (N) beats, which often appear nearly identical to the human eye.
  3. Stability in Learning: Utilizing Global Gradient Clipping and Stable Focal Loss to prevent divergence during training on noisy physiological data.

🏗️ Model Architecture

The model follows a specialized three-stage pipeline:

1. Residual CNN Feature Extractor

  • 1D ResNet Blocks: Extracts local morphological features while preventing gradient vanishing.
  • Spatial Dropout: Specifically designed for time-series to drop entire feature maps, forcing the model to learn more generalized patterns.
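
A minimal Keras sketch of one such block follows; the filter count, kernel size, and dropout rate are illustrative assumptions, not the repository's exact configuration:

```python
import tensorflow as tf
from tensorflow.keras import layers

def residual_block_1d(x, filters, kernel_size=5, dropout_rate=0.1):
    """One 1D ResNet block: two convolutions with a skip connection.

    SpatialDropout1D drops entire feature maps (channels), which suits
    time-series far better than dropping individual time steps.
    """
    shortcut = x
    y = layers.Conv1D(filters, kernel_size, padding="same")(x)
    y = layers.BatchNormalization()(y)
    y = layers.Activation("relu")(y)
    y = layers.SpatialDropout1D(dropout_rate)(y)
    y = layers.Conv1D(filters, kernel_size, padding="same")(y)
    y = layers.BatchNormalization()(y)
    # Project the skip path with a 1x1 convolution if channel counts differ
    if shortcut.shape[-1] != filters:
        shortcut = layers.Conv1D(filters, 1, padding="same")(shortcut)
    y = layers.Add()([shortcut, y])
    return layers.Activation("relu")(y)

inputs = layers.Input(shape=(187, 1))  # MIT-BIH beats are 187 samples long
features = residual_block_1d(inputs, filters=64)
```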

2. Transformer Contextualizer

  • Positional Embedding: Since Transformers are permutation-invariant, we inject temporal coordinates to help the model understand the sequence of the heart cycle.
  • Multi-Head Self-Attention: Allows the model to focus on different parts of the signal (e.g., the relationship between the P-wave and R-peak) simultaneously.
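
The two ideas above can be sketched in Keras as follows, assuming a learned positional embedding and illustrative head/width settings:

```python
import tensorflow as tf
from tensorflow.keras import layers

class PositionalEmbedding(layers.Layer):
    """Adds a learned positional embedding so attention can see order."""
    def __init__(self, seq_len, d_model, **kwargs):
        super().__init__(**kwargs)
        self.seq_len = seq_len
        self.pos_emb = layers.Embedding(input_dim=seq_len, output_dim=d_model)

    def call(self, x):
        positions = tf.range(start=0, limit=self.seq_len, delta=1)
        return x + self.pos_emb(positions)

def transformer_block(x, num_heads=4, key_dim=32, ff_dim=128):
    """Transformer encoder block: self-attention + feed-forward,
    each wrapped in a residual connection and layer normalization."""
    attn = layers.MultiHeadAttention(num_heads=num_heads, key_dim=key_dim)(x, x)
    x = layers.LayerNormalization()(x + attn)
    ff = layers.Dense(ff_dim, activation="relu")(x)
    ff = layers.Dense(x.shape[-1])(ff)
    return layers.LayerNormalization()(x + ff)

seq_len, d_model = 187, 64
inputs = layers.Input(shape=(seq_len, d_model))  # CNN feature maps
x = PositionalEmbedding(seq_len, d_model)(inputs)
x = transformer_block(x)
```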

3. Classification Head

  • Dual Pooling: Concatenates Global Average and Global Max Pooling to preserve both the most prominent and the most consistent features before the final Softmax layer.
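
The head can be sketched as follows (the dropout rate and feature width are illustrative assumptions):

```python
import tensorflow as tf
from tensorflow.keras import layers

inputs = layers.Input(shape=(187, 64))          # contextualized features
avg = layers.GlobalAveragePooling1D()(inputs)   # most consistent features
mx = layers.GlobalMaxPooling1D()(inputs)        # most prominent features
merged = layers.Concatenate()([avg, mx])        # shape: (batch, 128)
merged = layers.Dropout(0.3)(merged)            # illustrative rate
outputs = layers.Dense(5, activation="softmax")(merged)  # 5 beat classes
head = tf.keras.Model(inputs, outputs)
```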

🧪 Advanced Training Techniques

  • Stable Focal Loss: A specialized loss function that penalizes the model more for misclassifying "hard" minority examples, forcing it to learn beyond the majority class.
  • Physiological Augmentation: Online data augmentation including Gaussian Noise, Amplitude Scaling, and Time-Shifting (±10 samples) to simulate real-world sensor variability.
  • Weighted Random Sampling: Ensures every training batch is balanced, preventing the model from becoming biased toward Normal (N) beats.
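
Focal loss and global gradient clipping can be sketched like this; gamma and alpha follow the original focal-loss paper's defaults, and the repository's exact "Stable Focal Loss" settings may differ:

```python
import tensorflow as tf

def stable_focal_loss(gamma=2.0, alpha=0.25, eps=1e-7):
    """Focal loss for one-hot targets, with eps-clipping for stability."""
    def loss_fn(y_true, y_pred):
        y_pred = tf.clip_by_value(y_pred, eps, 1.0 - eps)  # avoid log(0)
        cross_entropy = -y_true * tf.math.log(y_pred)
        # (1 - p)^gamma down-weights easy, well-classified examples
        modulator = alpha * tf.pow(1.0 - y_pred, gamma)
        return tf.reduce_sum(modulator * cross_entropy, axis=-1)
    return loss_fn

# Global gradient clipping: rescale the entire gradient vector whenever
# its overall norm exceeds 1.0
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3, global_clipnorm=1.0)
```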

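The augmentation pipeline described above can be sketched in NumPy; parameter values are illustrative defaults, and np.roll wraps samples around the ends rather than zero-padding:

```python
import numpy as np

rng = np.random.default_rng(0)

def augment_beat(beat, noise_std=0.01, scale_range=(0.9, 1.1), max_shift=10):
    """Apply Gaussian noise, amplitude scaling, and a +/-10-sample
    time shift to a single beat."""
    beat = beat + rng.normal(0.0, noise_std, size=beat.shape)  # sensor noise
    beat = beat * rng.uniform(*scale_range)                    # amplitude scaling
    shift = int(rng.integers(-max_shift, max_shift + 1))       # time shift
    return np.roll(beat, shift)

beat = np.sin(np.linspace(0.0, 2.0 * np.pi, 187))  # synthetic 187-sample beat
augmented = augment_beat(beat)
```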
📊 Visual Results & Evaluation

1. Learning Curves

In this section, you can observe the stability of the model during training. Thanks to L2 Regularization and Global Gradient Clipping, the training and validation curves converge smoothly without extreme oscillations.

Training History
Figure 1: Loss and Accuracy curves for Training vs. Validation sets.

2. Model Confusion Matrix

The Confusion Matrix below illustrates the model's performance across all 5 classes. Note the high precision in the 'N' and 'Q' classes, and the improved sensitivity in the 'F' and 'V' categories.

Confusion Matrix
Figure 2: Confusion Matrix showcasing the identification of minority classes despite morphological similarities.


📊 Performance Metrics (MIT-BIH Dataset)

Evaluation performed on the MIT-BIH Arrhythmia Database:

Class   Description               Precision   Recall   F1-Score
N       Normal Beat               0.97        0.96     0.96
S       Supraventricular          0.42        0.35     0.38
V       Ventricular Escape        0.85        0.82     0.83
F       Fusion Beat               0.30        0.58     0.40
Q       Unclassified / Unknown    0.96        0.96     0.96
  • Overall Accuracy: ~93%
  • Macro Avg F1: ~0.71, pulled down mainly by the hard-to-separate S and F minority classes

🛠️ Getting Started

Prerequisites

```shell
pip install tensorflow numpy pandas matplotlib seaborn scikit-learn
```

Usage

  1. Place mitbih_train.csv and mitbih_test.csv in the Datasets/ directory.
  2. Run the main script to train the model.
  3. The best model weights will be saved automatically as RhythmAttention_best.keras.
```python
from model import build_rhythm_attention_model

# Load Architecture
model = build_rhythm_attention_model(input_shape=(187, 1), num_classes=5)
# Load Weights
model.load_weights('models/RhythmAttention_best.keras')
```

🏁 Conclusion

The RhythmAttention project successfully demonstrates that a Hybrid CNN-Transformer architecture is significantly more effective for ECG classification than traditional CNN or RNN models alone.

Key takeaways from this implementation include:

  • Synergy of Features: While the CNN layers successfully extract local morphology (P-QRS-T complexes), the Transformer layers provide a "global view" of the heart's rhythm, which is crucial for identifying arrhythmias.
  • Robustness to Imbalance: Through the use of Focal Loss and Weighted Sampling, the model effectively learns to identify rare but life-critical beats (like Fusion beats) without being overwhelmed by the majority "Normal" class.
  • Stability: Global Gradient Clipping and Residual Connections keep training stable and avoid the "Exploding Gradients" pitfall common in complex Attention-based architectures, while Spatial Dropout curbs overfitting.

Overall, RhythmAttention provides a reliable, scalable, and high-precision foundation for the next generation of automated cardiac diagnostic tools.


📝 Author

Sayyed Hossein Hosseini
Deep Learning Researcher & Healthcare AI Enthusiast

About

This project is part of the Attention-ECG series, adapting the hybrid architecture previously used in PulseAttention for the specific task of Arrhythmia classification using the MIT-BIH dataset.
