|
89 | 89 | }, |
90 | 90 | { |
91 | 91 | "cell_type": "code", |
92 | | - "execution_count": 7, |
| 92 | + "execution_count": 11, |
93 | 93 | "metadata": {}, |
94 | 94 | "outputs": [], |
95 | 95 | "source": [ |
| 96 | + "import os\n", |
96 | 97 | "import pandas as pd\n", |
97 | 98 | "import numpy as np\n", |
98 | 99 | "import tensorflow as tf\n", |
|
518 | 519 | "model = build_rhythm_attention_model()\n", |
519 | 520 | "model.summary()" |
520 | 521 | ] |
| 522 | + }, |
| 523 | + { |
| 524 | + "cell_type": "markdown", |
| 525 | + "metadata": {}, |
| 526 | + "source": [ |
| 527 | + "## **Callbacks**" |
| 528 | + ] |
| 529 | + }, |
| 530 | + { |
| 531 | + "cell_type": "code", |
| 532 | + "execution_count": 12, |
| 533 | + "metadata": {}, |
| 534 | + "outputs": [], |
| 535 | + "source": [ |
| 536 | + "# Create a folder to save the model if it doesn't exist\n", |
| 537 | + "if not os.path.exists('models'):\n", |
| 538 | + " os.makedirs('models')\n", |
| 539 | + "\n", |
| 540 | + "# Updated Callbacks\n", |
| 541 | + "callbacks = [\n", |
| 542 | + " # 1. Early Stopping: Stop training when val_loss stops improving\n", |
| 543 | + " tf.keras.callbacks.EarlyStopping(\n", |
| 544 | + " monitor='val_loss', \n", |
| 545 | + " patience=12, \n", |
| 546 | + " restore_best_weights=True,\n", |
| 547 | + " verbose=1\n", |
| 548 | + " ),\n", |
| 549 | + " \n", |
| 550 | + " # 2. Model Checkpoint: Save the best model based on validation loss\n", |
| 551 | + " tf.keras.callbacks.ModelCheckpoint(\n", |
| 552 | + " filepath='models/RhythmAttention_best.keras', \n", |
| 553 | + " monitor='val_loss', \n", |
| 554 | + " save_best_only=True,\n", |
| 555 | + " verbose=1\n", |
| 556 | + " ),\n", |
| 557 | + " \n", |
| 558 | + " # 3. Learning Rate Scheduler: Reduce LR when learning plateaus\n", |
| 559 | + " tf.keras.callbacks.ReduceLROnPlateau(\n", |
| 560 | + " monitor='val_loss', \n", |
| 561 | + " factor=0.5, \n", |
| 562 | + " patience=6, \n", |
| 563 | + " min_lr=0.00001,\n", |
| 564 | + " verbose=1\n", |
| 565 | + " )\n", |
| 566 | + "]" |
| 567 | + ] |
521 | 568 | } |
522 | 569 | ], |
523 | 570 | "metadata": { |
|
0 commit comments