Skip to content

Commit f193bc5

Browse files
Smart Data Generator
1 parent e8dff30 commit f193bc5

1 file changed

Lines changed: 66 additions & 0 deletions

File tree

RhythmAttention_Hybrid_CNN_Transformer_Architecture_for_Arrhythmia_Classification.ipynb

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,72 @@
221221
" factor = np.random.uniform(factor_range[0], factor_range[1])\n",
222222
" return signal * factor"
223223
]
224+
},
225+
{
226+
"cell_type": "markdown",
227+
"metadata": {},
228+
"source": [
229+
"## **Smart Data Generator**"
230+
]
231+
},
232+
{
233+
"cell_type": "code",
234+
"execution_count": 6,
235+
"metadata": {},
236+
"outputs": [
237+
{
238+
"name": "stdout",
239+
"output_type": "stream",
240+
"text": [
241+
"Batch data shape: (32, 187, 1)\n",
242+
"Batch labels shape: (32, 5)\n"
243+
]
244+
}
245+
],
246+
"source": [
247+
"# 7. Create a Python Generator for training\n",
248+
"def rhythm_attention_generator(X, y, batch_size, augment=True):\n",
249+
" \"\"\"\n",
250+
" Yields balanced batches using Weighted Sampling.\n",
251+
" Applies Online Augmentation to minority classes for every iteration.\n",
252+
" \"\"\"\n",
253+
" num_samples = len(X)\n",
254+
" indices = np.arange(num_samples)\n",
255+
" \n",
256+
" while True:\n",
257+
" # Perform Weighted Random Sampling for the batch\n",
258+
" batch_indices = np.random.choice(indices, size=batch_size, p=sample_probabilities)\n",
259+
" \n",
260+
" X_batch = X[batch_indices].copy()\n",
261+
" y_batch = y[batch_indices]\n",
262+
" \n",
263+
" if augment:\n",
264+
" for i in range(batch_size):\n",
265+
" # Apply heavier augmentation to minority classes (1, 2, 3, 4)\n",
266+
" if y_batch[i] != 0:\n",
267+
" X_batch[i] = apply_amplitude_scaling(add_gaussian_noise(X_batch[i]))\n",
268+
" else:\n",
269+
" # Apply very light noise to majority class to improve robustness\n",
270+
" if np.random.rand() > 0.7:\n",
271+
" X_batch[i] = add_gaussian_noise(X_batch[i], noise_level=0.002)\n",
272+
" \n",
273+
" # Convert labels to One-Hot Encoding for the Softmax output\n",
274+
" y_batch_onehot = tf.keras.utils.to_categorical(y_batch, num_classes=5)\n",
275+
" \n",
276+
" yield X_batch, y_batch_onehot\n",
277+
"\n",
278+
"# 8. Initialize the training generator and one-hot encode the test labels\n",
279+
"train_gen = rhythm_attention_generator(X_train, y_train, batch_size=32, augment=True)\n",
280+
"\n",
281+
"# Note: Validation/Test data should NOT be weighted or augmented\n",
282+
"# We convert it to One-Hot once for evaluation\n",
283+
"y_test_onehot = tf.keras.utils.to_categorical(y_test, num_classes=5)\n",
284+
"\n",
285+
"# Verify the generator output\n",
286+
"x_batch, y_batch = next(train_gen)\n",
287+
"print(f\"Batch data shape: {x_batch.shape}\")\n",
288+
"print(f\"Batch labels shape: {y_batch.shape}\")"
289+
]
224290
}
225291
],
226292
"metadata": {

0 commit comments

Comments
 (0)