-
-
Notifications
You must be signed in to change notification settings - Fork 270
Expand file tree
/
Copy pathtest_split.py
More file actions
87 lines (73 loc) · 2.91 KB
/
test_split.py
File metadata and controls
87 lines (73 loc) · 2.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
# License: BSD 3-Clause
from __future__ import annotations
import contextlib
import inspect
import os
import shutil
from pathlib import Path
import numpy as np
import pytest
from openml import OpenMLSplit
from openml.testing import TestBase
class OpenMLSplitTest(TestBase):
    """Tests for parsing and querying an ``OpenMLSplit`` loaded from a local ARFF fixture."""

    # Splitting this into smaller test cases is not helpful: these tests don't rely on
    # the server and take less than 5 seconds, and rebuilding the fixture would
    # potentially be costly.

    def setUp(self):
        """Locate the ARFF split fixture for task 1882 and its derived cache path."""
        super().setUp()
        # Resolve the fixture relative to this test module's location instead of
        # rebinding a local named ``__file__``.
        self.directory = os.path.dirname(inspect.getfile(OpenMLSplitTest))
        # ARFF file holding the precomputed data splits of task 1882.
        self.arff_filepath = Path(self.directory).parent.joinpath(
            "files", "org", "openml", "test", "tasks", "1882", "datasplits.arff"
        )
        # Pickle cache path derived from the ARFF file name ("datasplits.pkl.py3").
        # NOTE(review): presumably written by ``_from_arff_file``; tearDown removes it
        # so the fixture directory is left clean.
        self.pd_filename = self.arff_filepath.with_suffix(".pkl.py3")

    def tearDown(self):
        """Remove the pickle cache if one exists, then run the base-class cleanup."""
        try:
            # The cache may legitimately not exist (e.g. a test failed before parsing).
            with contextlib.suppress(FileNotFoundError):
                os.remove(self.pd_filename)
        finally:
            # Always mirror the ``super().setUp()`` call, even if removal failed.
            super().tearDown()

    def test_eq(self):
        """Two parses of one file are equal; changing any component breaks equality."""
        split = OpenMLSplit._from_arff_file(self.arff_filepath)

        # Sanity check: without it the inequality assertions below would pass
        # vacuously if ``__eq__`` always reported a difference.
        assert split == OpenMLSplit._from_arff_file(self.arff_filepath)

        split2 = OpenMLSplit._from_arff_file(self.arff_filepath)
        split2.name = "a"
        assert split != split2

        split2 = OpenMLSplit._from_arff_file(self.arff_filepath)
        split2.description = "a"
        assert split != split2

        # Extra key at the outermost level of the nested split mapping.
        split2 = OpenMLSplit._from_arff_file(self.arff_filepath)
        split2.split[10] = {}
        assert split != split2

        # Extra key one level deeper.
        split2 = OpenMLSplit._from_arff_file(self.arff_filepath)
        split2.split[0][10] = {}
        assert split != split2

    def test_from_arff_file(self):
        """Parsed splits are three nested dicts whose leaves hold train/test index arrays."""
        split = OpenMLSplit._from_arff_file(self.arff_filepath)
        assert isinstance(split.split, dict)
        assert isinstance(split.split[0], dict)
        assert isinstance(split.split[0][0], dict)

        # Leaves support both positional (0 = train, 1 = test) and attribute access.
        leaf = split.split[0][0][0]
        assert isinstance(leaf[0], np.ndarray)
        assert isinstance(leaf.train, np.ndarray)
        assert isinstance(leaf[1], np.ndarray)
        assert isinstance(leaf.test, np.ndarray)

        # Every entry of the first two levels (10 x 10) partitions the same
        # 898 instances into train and test parts.
        for i in range(10):
            for j in range(10):
                fold = split.split[i][j][0]
                n_train = fold.train.shape[0]
                n_test = fold.test.shape[0]
                assert n_train >= 808
                assert n_test >= 89
                assert n_train + n_test == 898

    def test_get_split(self):
        """``get`` returns (train, test) arrays and rejects unknown repeats/folds."""
        split = OpenMLSplit._from_arff_file(self.arff_filepath)
        train_split, test_split = split.get(fold=5, repeat=2)
        assert train_split.shape[0] == 808
        assert test_split.shape[0] == 90

        # Positional order is (repeat, fold): the first call passes repeat 10,
        # the second passes fold 10.
        with pytest.raises(ValueError, match="Repeat 10 not known"):
            split.get(10, 2)
        with pytest.raises(ValueError, match="Fold 10 not known"):
            split.get(2, 10)