Skip to content

Commit 3ec656c

Browse files
committed
FIX validation scripts
* Consider training set for computing SBS
1 parent a18a00b commit 3ec656c

3 files changed

Lines changed: 36 additions & 12 deletions

File tree

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ In `validation/`, we provide a script to validate your results files
3535
on known test data. Please note that all files of the test scenario have to be provided for this script.
3636
Example call:
3737

38-
```python validation/validate_cli.py --result_fn results.json --test_as example_files/SAT11-INDU-TRAIN/```
38+
```python validation/validate_cli.py --result_fn results.json --test_as example_files/SAT11-INDU-TEST/ --train_as example_files/SAT11-INDU-TRAIN/```
3939

4040
Add "." to your PYTHONPATH to avoid import errors, e.g.,
4141
```export PYTHONPATH=.//:$PYTHONPATH```

validation/validate.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def __init__(self, runtime_cutoff:int,
3535

3636
self.runtime_cutoff = runtime_cutoff
3737
self.maximize = maximize
38+
self.worse_than_sbs = 0 # int counter
3839

3940
self.logger = logging.getLogger("Stats")
4041

@@ -93,16 +94,19 @@ def show(self, remove_unsolvable: bool=True):
9394

9495
if self.maximize:
9596
self.logger.info("Gap closed: %.4f" %((par10 - self.sbs_par10) / (self.oracle_par10 - self.sbs_par10)))
97+
self.logger.info("Gap remaining: %.4f" %((self.oracle_par10 - par10) / (self.oracle_par10 - self.sbs_par10)))
9698
else:
9799
self.logger.info("Gap closed: %.4f" %((self.sbs_par10 - par10) / (self.sbs_par10 - self.oracle_par10)))
100+
self.logger.info("Gap remaining: %.4f" %((par10 - self.oracle_par10) / (self.sbs_par10 - self.oracle_par10)))
98101

99102
class Validator(object):
100103

101104
def __init__(self):
102105
''' Constructor '''
103106
self.logger = logging.getLogger("Validation")
104107

105-
def validate_runtime(self, schedules: dict, test_scenario: ASlibScenario):
108+
def validate_runtime(self, schedules: dict, test_scenario: ASlibScenario,
109+
train_scenario: ASlibScenario):
106110
'''
107111
validate selected schedules on test instances for runtime
108112
@@ -112,6 +116,8 @@ def validate_runtime(self, schedules: dict, test_scenario: ASlibScenario):
112116
algorithm schedules per instance
113117
test_scenario: ASlibScenario
114118
ASlib scenario with test instances
119+
train_scenario: ASlibScenario
120+
ASlib scenario with training instances -- required for determining the SBS
115121
'''
116122
if test_scenario.performance_type[0] != "runtime":
117123
raise ValueError("Cannot validate non-runtime scenario with runtime validation method")
@@ -133,7 +139,8 @@ def validate_runtime(self, schedules: dict, test_scenario: ASlibScenario):
133139
sys.exit(1)
134140

135141
stat.oracle_par10 = test_scenario.performance_data.min(axis=1).sum()
136-
stat.sbs_par10 = test_scenario.performance_data.sum(axis=0).min()
142+
sbs = train_scenario.performance_data.sum(axis=0).argmin()
143+
stat.sbs_par10 = test_scenario.performance_data.sum(axis=0)[sbs]
137144

138145
for inst, schedule in schedules.items():
139146
self.logger.debug("Validate: %s on %s" % (schedule, inst))
@@ -195,7 +202,8 @@ def validate_runtime(self, schedules: dict, test_scenario: ASlibScenario):
195202

196203
return stat
197204

198-
def validate_quality(self, schedules: dict, test_scenario: ASlibScenario):
205+
def validate_quality(self, schedules: dict, test_scenario: ASlibScenario,
206+
train_scenario: ASlibScenario):
199207
'''
200208
validate selected schedules on test instances for solution quality
201209
@@ -205,6 +213,8 @@ def validate_quality(self, schedules: dict, test_scenario: ASlibScenario):
205213
algorithm schedules per instance
206214
test_scenario: ASlibScenario
207215
ASlib scenario with test instances
216+
train_scenario: ASlibScenario
217+
ASlib scenario with training instances -- required for determining the SBS
208218
'''
209219
if test_scenario.performance_type[0] != "solution_quality":
210220
raise ValueError("Cannot validate non-solution_quality scenario with solution_quality validation method")
@@ -213,6 +223,7 @@ def validate_quality(self, schedules: dict, test_scenario: ASlibScenario):
213223

214224
if test_scenario.maximize[0]:
215225
test_scenario.performance_data *= -1
226+
train_scenario.performance_data *= -1
216227
self.logger.debug("Removing *-1 in performance data because of maximization")
217228

218229
stat = Stats(runtime_cutoff=None, maximize=test_scenario.maximize[0])
@@ -224,10 +235,12 @@ def validate_quality(self, schedules: dict, test_scenario: ASlibScenario):
224235

225236
if test_scenario.maximize[0]:
226237
stat.oracle_par10 = test_scenario.performance_data.max(axis=1).sum()
227-
stat.sbs_par10 = test_scenario.performance_data.sum(axis=0).max()
238+
sbs = train_scenario.performance_data.sum(axis=0).argmax()
239+
stat.sbs_par10 = test_scenario.performance_data.sum(axis=0)[sbs]
228240
else:
229241
stat.oracle_par10 = test_scenario.performance_data.min(axis=1).sum()
230-
stat.sbs_par10 = test_scenario.performance_data.sum(axis=0).min()
242+
sbs = train_scenario.performance_data.sum(axis=0).argmin()
243+
stat.sbs_par10 = test_scenario.performance_data.sum(axis=0)[sbs]
231244

232245
for inst, schedule in schedules.items():
233246

@@ -252,6 +265,14 @@ def validate_quality(self, schedules: dict, test_scenario: ASlibScenario):
252265

253266
stat.par1 += perf
254267
stat.solved += 1
268+
if test_scenario.maximize[0]:
269+
if perf < test_scenario.performance_data[sbs][inst]:
270+
stat.worse_than_sbs += 1
271+
print("%s(%.3f) vs %s (%.3f)" %(selected_algo, perf, sbs, test_scenario.performance_data[sbs][inst]))
272+
else:
273+
if perf > test_scenario.performance_data[sbs][inst]:
274+
stat.worse_than_sbs += 1
275+
print("%s(%.3f) vs %s (%.3f)" %(selected_algo, perf, sbs, test_scenario.performance_data[sbs][inst]))
255276

256277
stat.show(remove_unsolvable=False)
257278

validation/validate_cli.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,23 @@
1414
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
1515
parser.add_argument("--result_fn", help="Result json file with predictions for each test instances")
1616
parser.add_argument("--test_as", help="Directory with *all* test data in ASlib format")
17+
parser.add_argument("--train_as", help="Directory with *all* train data in ASlib format")
1718

1819
args_ = parser.parse_args()
1920

20-
#read scenario
21-
scenario = ASlibScenario()
22-
scenario.read_scenario(dn=args_.test_as)
21+
#read scenarios
22+
test_scenario = ASlibScenario()
23+
test_scenario.read_scenario(dn=args_.test_as)
24+
train_scenario = ASlibScenario()
25+
train_scenario.read_scenario(dn=args_.train_as)
2326

2427
# read result file
2528
with open(args_.result_fn) as fp:
2629
schedules = json.load(fp)
2730

2831
validator = Validator()
2932

30-
if scenario.performance_type[0] == "runtime":
31-
validator.validate_runtime(schedules=schedules, test_scenario=scenario)
33+
if test_scenario.performance_type[0] == "runtime":
34+
validator.validate_runtime(schedules=schedules, test_scenario=test_scenario, train_scenario=train_scenario)
3235
else:
33-
validator.validate_quality(schedules=schedules, test_scenario=scenario)
36+
validator.validate_quality(schedules=schedules, test_scenario=test_scenario, train_scenario=train_scenario)

0 commit comments

Comments
 (0)