|
221 | 221 | " factor = np.random.uniform(factor_range[0], factor_range[1])\n", |
222 | 222 | " return signal * factor" |
223 | 223 | ] |
| 224 | + }, |
| 225 | + { |
| 226 | + "cell_type": "markdown", |
| 227 | + "metadata": {}, |
| 228 | + "source": [ |
| 229 | + "## **Smart Data Generator**" |
| 230 | + ] |
| 231 | + }, |
| 232 | + { |
| 233 | + "cell_type": "code", |
| 234 | + "execution_count": 6, |
| 235 | + "metadata": {}, |
| 236 | + "outputs": [ |
| 237 | + { |
| 238 | + "name": "stdout", |
| 239 | + "output_type": "stream", |
| 240 | + "text": [ |
| 241 | + "Batch data shape: (32, 187, 1)\n", |
| 242 | + "Batch labels shape: (32, 5)\n" |
| 243 | + ] |
| 244 | + } |
| 245 | + ], |
| 246 | + "source": [ |
# 7. Create a Python Generator for training
def rhythm_attention_generator(X, y, batch_size, augment=True, probabilities=None):
    """
    Yield class-balanced training batches indefinitely.

    Each batch is drawn by weighted random sampling (with replacement) so
    minority classes appear more often. Online augmentation is applied per
    batch: minority-class samples (label != 0) always get Gaussian noise
    followed by amplitude scaling, while majority-class samples get very
    light noise roughly 30% of the time for robustness.

    Args:
        X: sample array indexable by a 1-D integer index array
           (e.g. np.ndarray of shape (num_samples, timesteps, channels)).
        y: integer class labels aligned with X; 0 is the majority class.
        batch_size: number of samples drawn per batch.
        augment: apply online augmentation when True.
        probabilities: optional per-sample sampling weights (must sum to 1).
            Defaults to the notebook-global ``sample_probabilities``,
            preserving the original behavior; passing it explicitly removes
            the hidden global dependency and makes the function reusable.

    Yields:
        (X_batch, y_batch_onehot): the (possibly augmented) signals and the
        labels one-hot encoded for a 5-class softmax output.
    """
    # Backward-compatible fallback to the module-level weights computed
    # earlier in the notebook.
    p = sample_probabilities if probabilities is None else probabilities
    indices = np.arange(len(X))

    while True:
        # Weighted random sampling balances the class mix of every batch.
        batch_indices = np.random.choice(indices, size=batch_size, p=p)

        # Copy so augmentation never mutates the underlying dataset.
        X_batch = X[batch_indices].copy()
        y_batch = y[batch_indices]

        if augment:
            for i in range(batch_size):
                if y_batch[i] != 0:
                    # Heavier augmentation for minority classes (1, 2, 3, 4).
                    X_batch[i] = apply_amplitude_scaling(add_gaussian_noise(X_batch[i]))
                elif np.random.rand() > 0.7:
                    # Very light noise on the majority class to improve robustness.
                    X_batch[i] = add_gaussian_noise(X_batch[i], noise_level=0.002)

        # Convert labels to one-hot encoding for the softmax output.
        yield X_batch, tf.keras.utils.to_categorical(y_batch, num_classes=5)
# 8. Build the generators used during training.
# The training stream is weighted-sampled and augmented on the fly.
train_gen = rhythm_attention_generator(X_train, y_train, batch_size=32, augment=True)

# Evaluation data must stay untouched: no weighting, no augmentation.
# A single one-hot conversion of the test labels is enough.
y_test_onehot = tf.keras.utils.to_categorical(y_test, num_classes=5)

# Sanity check: pull one batch from the generator and inspect its dimensions.
x_batch, y_batch = next(train_gen)
print(f"Batch data shape: {x_batch.shape}")
print(f"Batch labels shape: {y_batch.shape}")
| 289 | + ] |
224 | 290 | } |
225 | 291 | ], |
226 | 292 | "metadata": { |
|
0 commit comments