# 🫀 RhythmAttention: Hybrid CNN-Transformer Architecture for Arrhythmia Classification

[![Python](https://img.shields.io/badge/Python-3.9+-blue?logo=python)](https://www.python.org/)
[![TensorFlow](https://img.shields.io/badge/TensorFlow-2.15+-orange?logo=tensorflow)](https://www.tensorflow.org/)
[![Deep Learning](https://img.shields.io/badge/Area-Healthcare_AI-red)](#)
[![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)](https://opensource.org/licenses/MIT)

**RhythmAttention** is a high-performance Deep Learning framework designed for the automated classification of Electrocardiogram (ECG) signals. By merging the spatial feature extraction capabilities of **Residual CNNs** with the long-range temporal modeling of **Transformers (Self-Attention)**, this architecture achieves state-of-the-art robustness in identifying cardiac arrhythmias.

---

## 🚀 Key Innovation: The Hybrid Approach

Standard CNNs excel at capturing morphological features (P-QRS-T wave shapes), while Transformers excel at understanding the "rhythm" over time. **RhythmAttention** combines both to tackle the most difficult challenges in ECG analysis:

1. **Extreme Class Imbalance:** Effectively managing the roughly 30:1 ratio between Normal and minority beats (like Fusion or Supraventricular beats).
2. **Morphological Mimicry:** Distinguishing between Supraventricular (S) and Normal (N) beats, which often appear nearly identical to the human eye.
3. **Stability in Learning:** Utilizing **Global Gradient Clipping** and **Stable Focal Loss** to prevent divergence during training on noisy physiological data (see the snippet below).
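
For context, global gradient clipping is a one-line change in Keras. The sketch below is illustrative only; the clipping threshold of 1.0 and the learning rate are assumptions, not the project's tuned values:

```python
import tensorflow as tf

# Clip the global norm of all gradients before each update step.
# global_clipnorm=1.0 is a hypothetical threshold chosen for illustration.
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3, global_clipnorm=1.0)
```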
---

## 🏗️ Model Architecture

The model follows a specialized three-stage pipeline:

### 1. Residual CNN Feature Extractor
* **1D ResNet Blocks:** Extract local morphological features while the skip connections prevent vanishing gradients.
* **Spatial Dropout:** Drops entire feature maps rather than individual timesteps, a variant suited to time-series that forces the model to learn more generalized patterns.
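
A minimal sketch of such a block in Keras (the filter count, kernel size, and dropout rate here are illustrative assumptions, not the repository's exact hyperparameters):

```python
import tensorflow as tf
from tensorflow.keras import layers

def residual_block_1d(x, filters=64, kernel_size=7, dropout_rate=0.1):
    """1D residual block: two Conv-BN stages with a skip connection."""
    shortcut = x
    y = layers.Conv1D(filters, kernel_size, padding="same")(x)
    y = layers.BatchNormalization()(y)
    y = layers.Activation("relu")(y)
    y = layers.Conv1D(filters, kernel_size, padding="same")(y)
    y = layers.BatchNormalization()(y)
    # Match channel counts on the skip path if needed.
    if shortcut.shape[-1] != filters:
        shortcut = layers.Conv1D(filters, 1, padding="same")(shortcut)
    y = layers.Activation("relu")(layers.Add()([y, shortcut]))
    # SpatialDropout1D drops entire feature maps, not individual timesteps.
    return layers.SpatialDropout1D(dropout_rate)(y)
```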

### 2. Transformer Contextualizer
* **Positional Embedding:** Since Transformers are permutation-invariant, we inject temporal coordinates to help the model understand the sequence of the heart cycle.
* **Multi-Head Self-Attention:** Allows the model to focus on different parts of the signal (e.g., the relationship between the P-wave and R-peak) simultaneously.
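
A sketch of both components in Keras; the head count, key dimension, and feed-forward width are illustrative assumptions:

```python
import tensorflow as tf
from tensorflow.keras import layers

class PositionalEmbedding(layers.Layer):
    """Adds a learned positional embedding so attention can see temporal order."""
    def __init__(self, seq_len, channels, **kwargs):
        super().__init__(**kwargs)
        self.pos_emb = layers.Embedding(input_dim=seq_len, output_dim=channels)
        self.seq_len = seq_len

    def call(self, x):
        positions = tf.range(self.seq_len)
        return x + self.pos_emb(positions)

def transformer_block(x, num_heads=4, key_dim=32, ff_dim=128, dropout_rate=0.1):
    """Standard encoder block: self-attention + feed-forward, each with a residual."""
    attn = layers.MultiHeadAttention(num_heads=num_heads, key_dim=key_dim)(x, x)
    x = layers.LayerNormalization()(layers.Add()([x, attn]))
    ff = layers.Dense(ff_dim, activation="relu")(x)
    ff = layers.Dropout(dropout_rate)(layers.Dense(x.shape[-1])(ff))
    return layers.LayerNormalization()(layers.Add()([x, ff]))
```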

### 3. Classification Head
* **Dual Pooling:** Concatenates Global Average and Global Max Pooling to preserve both the most prominent and the most consistent features before the final Softmax layer.
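
A sketch of the dual-pooling head (the dropout rate is an illustrative assumption):

```python
from tensorflow.keras import layers

def classification_head(x, num_classes=5, dropout_rate=0.3):
    """Concatenate average- and max-pooled features, then classify with Softmax."""
    avg_pool = layers.GlobalAveragePooling1D()(x)  # the most consistent features
    max_pool = layers.GlobalMaxPooling1D()(x)      # the most prominent features
    merged = layers.Concatenate()([avg_pool, max_pool])
    merged = layers.Dropout(dropout_rate)(merged)
    return layers.Dense(num_classes, activation="softmax")(merged)
```

Average pooling captures trends that persist across the whole beat, while max pooling preserves sharp, localized peaks such as the R-wave; concatenating them gives the classifier both views.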
---

## 🧪 Advanced Training Techniques

* **Stable Focal Loss:** A specialized loss function that penalizes the model more heavily for misclassifying "hard" minority examples, forcing it to learn beyond the majority class (see the sketch after this list).
* **Physiological Augmentation:** Online data augmentation including **Gaussian Noise**, **Amplitude Scaling**, and **Time-Shifting** (±10 samples) to simulate real-world sensor variability.
* **Weighted Random Sampling:** Ensures every training batch is balanced, preventing the model from becoming biased toward Normal (N) beats.
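
A sketch of the first two techniques. The gamma, alpha, noise level, and scaling range are illustrative assumptions; only the ±10-sample shift is taken from the description above:

```python
import tensorflow as tf

def stable_focal_loss(gamma=2.0, alpha=0.25, eps=1e-7):
    """Multi-class focal loss with probability clipping for numerical stability."""
    def loss_fn(y_true, y_pred):
        # Clip predictions away from 0 and 1 to keep log() finite.
        y_pred = tf.clip_by_value(y_pred, eps, 1.0 - eps)
        cross_entropy = -y_true * tf.math.log(y_pred)
        # (1 - p)^gamma down-weights easy examples, focusing on hard minority beats.
        return tf.reduce_sum(alpha * tf.pow(1.0 - y_pred, gamma) * cross_entropy, axis=-1)
    return loss_fn

def augment_beat(beat, noise_std=0.01, max_shift=10):
    """Online augmentation for one (187, 1) beat: noise, scaling, time-shift."""
    beat = beat + tf.random.normal(tf.shape(beat), stddev=noise_std)  # Gaussian noise
    beat = beat * tf.random.uniform([], 0.9, 1.1)                     # amplitude scaling
    shift = tf.random.uniform([], -max_shift, max_shift + 1, dtype=tf.int32)
    return tf.roll(beat, shift=shift, axis=0)                         # time-shifting
```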
---

## 📊 Visual Results & Evaluation

### 1. Learning Curves
The model trains stably: thanks to **L2 Regularization** and **Global Gradient Clipping**, the training and validation curves converge smoothly without extreme oscillations.

![Training History](results/learning_curves.png)
*Figure 1: Loss and Accuracy curves for Training vs. Validation sets.*

### 2. Model Confusion Matrix
The Confusion Matrix below illustrates the model's performance across all 5 classes. Note the high precision on the 'N' and 'Q' classes and the improved sensitivity on the 'F' and 'V' categories.

![Confusion Matrix](results/confusion_matrix.png)
*Figure 2: Confusion Matrix showcasing the identification of minority classes despite morphological similarities.*

---

## 📊 Performance Metrics (MIT-BIH Dataset)

Evaluation performed on the **MIT-BIH Arrhythmia Database**:

| Class | Description | Precision | Recall | F1-Score |
|:---:|:---|:---:|:---:|:---:|
| **N** | Normal Beat | 0.97 | 0.96 | **0.96** |
| **S** | Supraventricular | 0.42 | 0.35 | **0.38** |
| **V** | Ventricular Ectopic | 0.85 | 0.82 | **0.83** |
| **F** | Fusion Beat | 0.30 | 0.58 | **0.40** |
| **Q** | Unclassified / Unknown | 0.96 | 0.96 | **0.96** |

* **Overall Accuracy:** ~93%
* **Macro Avg F1:** ~0.71, the unweighted mean of the five per-class F1 scores ((0.96 + 0.38 + 0.83 + 0.40 + 0.96) / 5 ≈ 0.71), which gives the rare S and F classes the same weight as the dominant N class

---

## 🛠️ Getting Started

### Prerequisites
```bash
pip install tensorflow numpy pandas matplotlib seaborn scikit-learn
```

### Usage
1. Place `mitbih_train.csv` and `mitbih_test.csv` in the `Datasets/` directory.
2. Run the main script to train the model.
3. The best model weights will be saved automatically as `RhythmAttention_best.keras`.

```python
from model import build_rhythm_attention_model

# Load Architecture
model = build_rhythm_attention_model(input_shape=(187, 1), num_classes=5)
# Load Weights
model.load_weights('models/RhythmAttention_best.keras')
```
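
Once the weights are loaded, inference is straightforward. The sketch below assumes the standard Kaggle CSV layout (187 signal samples per row followed by the label) and the usual class-index order 0=N, 1=S, 2=V, 3=F, 4=Q:

```python
import numpy as np
import pandas as pd

# Each row: 187 signal samples followed by the class label (assumed layout).
test_df = pd.read_csv('Datasets/mitbih_test.csv', header=None)
X_test = test_df.iloc[:, :-1].to_numpy().reshape(-1, 187, 1)

# Predict class probabilities and take the argmax as the beat label.
probabilities = model.predict(X_test)
predicted_classes = np.argmax(probabilities, axis=1)
```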

## 🏁 Conclusion

The **RhythmAttention** project demonstrates that a **Hybrid CNN-Transformer architecture** is significantly more effective for ECG classification than traditional CNN or RNN models alone.

Key takeaways from this implementation include:
* **Synergy of Features:** While the CNN layers extract local morphology (P-QRS-T complexes), the Transformer layers provide a "global view" of the heart's rhythm, which is crucial for identifying arrhythmias.
* **Robustness to Imbalance:** Through the use of **Focal Loss** and **Weighted Sampling**, the model effectively learns to identify rare but life-critical beats (like Fusion beats) without being overwhelmed by the majority "Normal" class.
* **Stability:** **Spatial Dropout**, **Residual Connections**, and **Global Gradient Clipping** keep training stable and avoid the "Exploding Gradients" pitfall often seen in complex Attention-based architectures.

Overall, RhythmAttention provides a reliable, scalable, and high-precision foundation for the next generation of automated cardiac diagnostic tools.

---

## 📝 Author
**Sayyed Hossein Hosseini**
*Deep Learning Researcher & Healthcare AI Enthusiast*
