Skip to content

Commit cc8989a

Browse files
Weighted Sampling Preparation
1 parent af28292 commit cc8989a

1 file changed

Lines changed: 45 additions & 0 deletions

File tree

RhythmAttention_Hybrid_CNN_Transformer_Architecture_for_Arrhythmia_Classification.ipynb

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,51 @@
151151
"print(f\"X_train shape: {X_train.shape}\") # Expected: (87554, 187, 1)\n",
152152
"print(f\"X_test shape: {X_test.shape}\") # Expected: (21892, 187, 1)"
153153
]
154+
},
155+
{
156+
"cell_type": "markdown",
157+
"metadata": {},
158+
"source": [
159+
"## **Weighted Sampling Preparation**"
160+
]
161+
},
162+
{
163+
"cell_type": "code",
164+
"execution_count": 4,
165+
"metadata": {},
166+
"outputs": [
167+
{
168+
"name": "stdout",
169+
"output_type": "stream",
170+
"text": [
171+
"Class Weights calculated for balanced learning:\n",
172+
" Class 0: 0.2416\n",
173+
" Class 1: 7.8771\n",
174+
" Class 2: 3.0254\n",
175+
" Class 3: 27.3179\n",
176+
" Class 4: 2.7229\n"
177+
]
178+
}
179+
],
180+
"source": [
181+
"# 4. Calculate class weights to handle extreme imbalance\n",
182+
"# 'balanced' mode automatically assigns higher weights to minority classes\n",
183+
"weights = class_weight.compute_class_weight(\n",
184+
" class_weight='balanced',\n",
185+
" classes=np.unique(y_train),\n",
186+
" y=y_train\n",
187+
")\n",
188+
"class_weights_dict = dict(enumerate(weights))\n",
189+
"\n",
190+
"# 5. Map weights to every individual sample for the sampler\n",
191+
"sample_weights = np.array([class_weights_dict[cls] for cls in y_train])\n",
192+
"# Normalize weights into probabilities\n",
193+
"sample_probabilities = sample_weights / np.sum(sample_weights)\n",
194+
"\n",
195+
"print(\"Class Weights calculated for balanced learning:\")\n",
196+
"for cls, w in class_weights_dict.items():\n",
197+
" print(f\" Class {cls}: {w:.4f}\")"
198+
]
154199
}
155200
],
156201
"metadata": {

0 commit comments

Comments
 (0)