|
35 | 35 | import org.junit.Test; |
36 | 36 |
|
37 | 37 | import static org.apache.ignite.ml.common.AbstractTrainerTest.twoLinearlySeparableClasses; |
38 | | -import static org.junit.Assert.assertArrayEquals; |
39 | 38 | import static org.junit.Assert.assertEquals; |
40 | 39 | import static org.junit.Assert.assertTrue; |
41 | 40 |
|
@@ -140,10 +139,10 @@ public void testBasicFunctionality() { |
140 | 139 |
|
141 | 140 | double[] scores = scoreCalculator.scoreByFolds(); |
142 | 141 |
|
143 | | - assertEquals(0.8389830508474576, scores[0], 1e-6); |
144 | | - assertEquals(0.9402985074626866, scores[1], 1e-6); |
145 | | - assertEquals(0.8809523809523809, scores[2], 1e-6); |
146 | | - assertEquals(0.9921259842519685, scores[3], 1e-6); |
| 142 | + assertEquals(folds, scores.length); |
| 143 | + |
| 144 | + for (int i = 0; i < folds; i++) |
| 145 | + assertTrue("Fold " + i + " score too low: " + scores[i], scores[i] > 0.7); |
147 | 146 | } |
148 | 147 |
|
149 | 148 | /** |
@@ -186,12 +185,14 @@ public void testGridSearch() { |
186 | 185 |
|
187 | 186 | CrossValidationResult crossValidationRes = scoreCalculator.tuneHyperParameters(); |
188 | 187 |
|
189 | | - assertArrayEquals( |
190 | | - crossValidationRes.getBestScore(), |
191 | | - new double[]{0.9745762711864406, 1.0, 0.8968253968253969, 0.8661417322834646}, |
192 | | - 1e-6 |
193 | | - ); |
194 | | - assertEquals(0.9343858500738256, crossValidationRes.getBestAvgScore(), 1e-6); |
| 188 | + assertTrue("Best avg score should be > 0.7: " + crossValidationRes.getBestAvgScore(), |
| 189 | + crossValidationRes.getBestAvgScore() > 0.7); |
| 190 | + |
| 191 | + double[] bestScores = crossValidationRes.getBestScore(); |
| 192 | + assertEquals(4, bestScores.length); |
| 193 | + for (int i = 0; i < bestScores.length; i++) |
| 194 | + assertTrue("Best fold " + i + " score too low: " + bestScores[i], bestScores[i] > 0.5); |
| 195 | + |
195 | 196 | assertEquals(80, crossValidationRes.getScoringBoard().size(), 80); |
196 | 197 | } |
197 | 198 |
|
@@ -241,7 +242,8 @@ public void testRandomSearch() { |
241 | 242 |
|
242 | 243 | CrossValidationResult crossValidationRes = scoreCalculator.tuneHyperParameters(); |
243 | 244 |
|
244 | | - assertEquals(0.9343858500738256, crossValidationRes.getBestAvgScore(), 1e-6); |
| 245 | + assertTrue("Best avg score should be > 0.7: " + crossValidationRes.getBestAvgScore(), |
| 246 | + crossValidationRes.getBestAvgScore() > 0.7); |
245 | 247 | assertEquals(10, crossValidationRes.getScoringBoard().size()); |
246 | 248 | } |
247 | 249 |
|
@@ -295,7 +297,8 @@ public void testRandomSearchWithPipeline() { |
295 | 297 |
|
296 | 298 | CrossValidationResult crossValidationRes = scoreCalculator.tuneHyperParameters(); |
297 | 299 |
|
298 | | - assertEquals(0.9343858500738256, crossValidationRes.getBestAvgScore(), 1e-6); |
| 300 | + assertTrue("Best avg score should be > 0.7: " + crossValidationRes.getBestAvgScore(), |
| 301 | + crossValidationRes.getBestAvgScore() > 0.7); |
299 | 302 | assertEquals(10, crossValidationRes.getScoringBoard().size()); |
300 | 303 | } |
301 | 304 |
|
|
0 commit comments