Skip to content

Commit c8b179f

Browse files
Merge pull request #337 from KernelTuner/fix_issue_333
This fixes issue #333 on backwards compatibility with the old restrictions function
2 parents 1f63c7c + 38700af commit c8b179f

2 files changed

Lines changed: 43 additions & 4 deletions

File tree

kernel_tuner/searchspace.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from warnings import warn
88
from copy import deepcopy
99
from collections import defaultdict, deque
10+
from inspect import signature
1011

1112
import numpy as np
1213
from scipy.stats.qmc import LatinHypercube
@@ -499,6 +500,13 @@ def __build_searchspace(self, block_size_names: list, max_threads: int, solver:
499500
def __add_restrictions(self, parameter_space: Problem) -> Problem:
500501
"""Add the user-specified restrictions as constraints on the parameter space."""
501502
restrictions = deepcopy(self.restrictions)
503+
# differentiate between old style monolithic with single 'p' argument and newer *args style
504+
if (len(restrictions) == 1
505+
and not isinstance(restrictions[0], (Constraint, FunctionConstraint, str))
506+
and callable(restrictions[0])
507+
and len(signature(restrictions[0]).parameters) == 1
508+
and len(self.param_names) > 1):
509+
restrictions = restrictions[0]
502510
if isinstance(restrictions, list):
503511
for restriction in restrictions:
504512
required_params = self.param_names
@@ -508,10 +516,6 @@ def __add_restrictions(self, parameter_space: Problem) -> Problem:
508516
required_params = restriction[1]
509517
restriction = restriction[0]
510518
if callable(restriction) and not isinstance(restriction, Constraint):
511-
# def restrictions_wrapper(*args):
512-
# return check_instance_restrictions(restriction, dict(zip(self.param_names, args)), False)
513-
# print(restriction, isinstance(restriction, Constraint))
514-
# restriction = FunctionConstraint(restrictions_wrapper)
515519
restriction = FunctionConstraint(restriction, required_params)
516520

517521
# add as a Constraint
@@ -533,6 +537,7 @@ def __add_restrictions(self, parameter_space: Problem) -> Problem:
533537
elif callable(restrictions):
534538

535539
def restrictions_wrapper(*args):
540+
"""Wrap old-style monolithic restrictions to work with multiple arguments."""
536541
return check_instance_restrictions(restrictions, dict(zip(self.param_names, args)), False)
537542

538543
parameter_space.addConstraint(FunctionConstraint(restrictions_wrapper), self.param_names)

test/test_searchspace.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -642,3 +642,37 @@ def test_full_searchspace(compare_against_bruteforce=False):
642642
compare_two_searchspace_objects(searchspace, searchspace_bruteforce)
643643
else:
644644
assert searchspace.size == len(searchspace.list) == 349853
645+
646+
def test_restriction_backwards_compatibility():
647+
"""Test whether the backwards compatibility code for restrictions (list of strings) works as expected."""
648+
# create a searchspace with mixed parameter types
649+
max_threads = 1024
650+
tune_params = dict()
651+
tune_params["N_PER_BLOCK"] = [32, 64, 128, 256, 512, 1024]
652+
tune_params["M_PER_BLOCK"] = [32, 64, 128, 256, 512, 1024]
653+
tune_params["block_size_y"] = [1, 2, 4, 8, 16, 32]
654+
tune_params["block_size_z"] = [1, 2, 4, 8, 16, 32]
655+
656+
# old style monolithic restriction function
657+
def restrict(p):
658+
n_global_per_warp = int(p["N_PER_BLOCK"] // p["block_size_y"])
659+
m_global_per_warp = int(p["M_PER_BLOCK"] // p["block_size_z"])
660+
if n_global_per_warp == 0 or m_global_per_warp == 0:
661+
return False
662+
663+
searchspace_callable = Searchspace(tune_params, restrict, max_threads)
664+
665+
def restrict_args(N_PER_BLOCK, M_PER_BLOCK, block_size_y, block_size_z):
666+
n_global_per_warp = int(N_PER_BLOCK // block_size_y)
667+
m_global_per_warp = int(M_PER_BLOCK // block_size_z)
668+
if n_global_per_warp == 0 or m_global_per_warp == 0:
669+
return False
670+
671+
# args-style restriction
672+
searchspace_str = Searchspace(tune_params, restrict_args, max_threads)
673+
674+
# check the size
675+
assert searchspace_str.size == searchspace_callable.size
676+
677+
# check that both searchspaces are identical in outcome
678+
compare_two_searchspace_objects(searchspace_str, searchspace_callable)

0 commit comments

Comments
 (0)