From 2736c3fafdf7fd5a536e185f4da4b2d9db676a58 Mon Sep 17 00:00:00 2001 From: stijn Date: Wed, 14 Jan 2026 14:12:33 +0100 Subject: [PATCH 1/3] Fix issue where restriction lambdas cannot contain captured variables. --- kernel_tuner/util.py | 33 +++++++++++++++++++++------------ 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/kernel_tuner/util.py b/kernel_tuner/util.py index 2d9e3f1b..386c5ce2 100644 --- a/kernel_tuner/util.py +++ b/kernel_tuner/util.py @@ -1039,6 +1039,11 @@ def get_all_lambda_asts(func): return res +class InvalidLambdaError(Exception): + def __str__(self): + return "lambda could not be parsed by Kernel Tuner" + + class ConstraintLambdaTransformer(ast.NodeTransformer): """Replaces any `NAME['string']` subscript with just `'string'`, if `NAME` matches the lambda argument name. @@ -1046,6 +1051,13 @@ class ConstraintLambdaTransformer(ast.NodeTransformer): def __init__(self, dict_arg_name): self.dict_arg_name = dict_arg_name + def visit_Name(self, node): + # If we find a Name node that is not part of a Subscript expression, then + # we throw an exception. This happens when a lambda contains a captured + # variable or calls a function. In these cases, we cannot transform the + # lambda into a string so we just exit the ast transformer. + raise InvalidLambdaError() + def visit_Subscript(self, node): # We only replace subscript expressions of the form ['some_string'] if (isinstance(node.value, ast.Name) @@ -1062,19 +1074,19 @@ def unparse_constraint_lambda(lambda_ast): Returns string body of the rewritten lambda function """ args = lambda_ast.args - body = lambda_ast.args # Kernel Tuner only allows constraint lambdas with a single argument - arg = args.args[0].arg + if len(args.args) != 1: + raise InvalidLambdaError() + + first_arg = args.args[0].arg # Create transformer that replaces accesses to tunable parameter dict # with simply the name of the tunable parameter - transformer = ConstraintLambdaTransformer(arg) + transformer = ConstraintLambdaTransformer(first_arg) new_lambda_ast = transformer.visit(lambda_ast) - rewritten_lambda_body_as_string = ast.unparse(new_lambda_ast.body).strip() - - return rewritten_lambda_body_as_string + return ast.unparse(new_lambda_ast.body).strip() def convert_constraint_lambdas(restrictions): @@ -1083,16 +1095,13 @@ def convert_constraint_lambdas(restrictions): for c in restrictions: if isinstance(c, (str, Constraint)): res.append(c) - if callable(c) and not isinstance(c, Constraint): + elif callable(c): try: lambda_asts = get_all_lambda_asts(c) - except ValueError: + res += [unparse_constraint_lambda(lambda_ast) for lambda_ast in lambda_asts] + except (InvalidLambdaError, ValueError): res.append(c) # it's just a plain function, not a lambda - continue - for lambda_ast in lambda_asts: - new_c = unparse_constraint_lambda(lambda_ast) - res.append(new_c) result = list(set(res)) if not len(result) == len(restrictions): From 3b33eaca88e3a8b212af019ef0fea969d153dbf7 Mon Sep 17 00:00:00 2001 From: stijn Date: Wed, 14 Jan 2026 14:26:38 +0100 Subject: [PATCH 2/3] Move functionality related to restrictions to a seperate file --- kernel_tuner/restrictions.py | 391 +++++++++++++++++++++++++++++++++ kernel_tuner/util.py | 410 +---------------------------------- 2 files changed, 394 insertions(+), 407 deletions(-) create mode 100644 kernel_tuner/restrictions.py diff --git a/kernel_tuner/restrictions.py b/kernel_tuner/restrictions.py new file mode 100644 index 00000000..e74fb944 --- /dev/null +++ b/kernel_tuner/restrictions.py @@ -0,0 +1,391 @@ +"""Module for parsing and evaluating restrictions.""" + +from inspect import getsource +from types import FunctionType +from typing import Union +import ast +import logging +import numpy as np +import re +import textwrap + +from constraint import ( + AllDifferentConstraint, + AllEqualConstraint, + Constraint, + ExactSumConstraint, + FunctionConstraint, + InSetConstraint, + MaxProdConstraint, + MaxSumConstraint, + MinProdConstraint, + MinSumConstraint, + NotInSetConstraint, + SomeInSetConstraint, + SomeNotInSetConstraint, +) + + +def check_restriction(restrict, params: dict) -> bool: + """Check whether a configuration meets a search space restriction.""" + # if it's a python-constraint, convert to function and execute + if isinstance(restrict, Constraint): + restrict = convert_constraint_restriction(restrict) + return restrict(list(params.values())) + # if it's a string, fill in the parameters and evaluate + elif isinstance(restrict, str): + return eval(replace_param_occurrences(restrict, params)) + # if it's a function, call it + elif callable(restrict): + return restrict(**params) + # if it's a tuple, use only the parameters in the second argument to call the restriction + elif ( + isinstance(restrict, tuple) + and (len(restrict) == 2 or len(restrict) == 3) + and callable(restrict[0]) + and isinstance(restrict[1], (list, tuple)) + ): + # unpack the tuple + if len(restrict) == 2: + restrict, selected_params = restrict + else: + restrict, selected_params, source = restrict + # look up the selected parameters and their value + selected_params = dict((key, params[key]) for key in selected_params) + # call the restriction + if isinstance(restrict, Constraint): + restrict = convert_constraint_restriction(restrict) + return restrict(list(selected_params.values())) + else: + return restrict(**selected_params) + # otherwise, raise an error + else: + raise ValueError(f"Unknown restriction type {type(restrict)} ({restrict})") + + +def check_restrictions(restrictions, params: dict, verbose: bool) -> bool: + """Check whether a configuration meets the search space restrictions.""" + if callable(restrictions): + valid = restrictions(params) + if not valid and verbose is True: + print(f"skipping config {get_instance_string(params)}, reason: config fails restriction") + return valid + valid = True + for restrict in restrictions: + # Check the type of each restriction and validate accordingly. Re-implement as a switch when Python >= 3.10. + try: + valid = check_restriction(restrict, params) + if not valid: + break + except ZeroDivisionError: + logging.debug(f"Restriction {restrict} with configuration {get_instance_string(params)} divides by zero.") + if not valid and verbose is True: + print(f"skipping config {get_instance_string(params)}, reason: config fails restriction {restrict}") + return valid + + +def convert_constraint_restriction(restrict: Constraint): + """Convert the python-constraint to a function for backwards compatibility.""" + if isinstance(restrict, FunctionConstraint): + + def f_restrict(p): + return restrict._func(*p) + + elif isinstance(restrict, AllDifferentConstraint): + + def f_restrict(p): + return len(set(p)) == len(p) + + elif isinstance(restrict, AllEqualConstraint): + + def f_restrict(p): + return all(x == p[0] for x in p) + + elif isinstance(restrict, MaxProdConstraint): + + def f_restrict(p): + return np.prod(p) <= restrict._maxprod + + elif isinstance(restrict, MinProdConstraint): + + def f_restrict(p): + return np.prod(p) >= restrict._minprod + + elif isinstance(restrict, MaxSumConstraint): + + def f_restrict(p): + return sum(p) <= restrict._maxsum + + elif isinstance(restrict, ExactSumConstraint): + + def f_restrict(p): + return sum(p) == restrict._exactsum + + elif isinstance(restrict, MinSumConstraint): + + def f_restrict(p): + return sum(p) >= restrict._minsum + + elif isinstance(restrict, (InSetConstraint, NotInSetConstraint, SomeInSetConstraint, SomeNotInSetConstraint)): + raise NotImplementedError( + f"Restriction of the type {type(restrict)} is explicitly not supported in backwards compatibility mode, because the behaviour is too complex. Please rewrite this constraint to a function to use it with this algorithm." + ) + else: + raise TypeError(f"Unrecognized restriction {restrict}") + return f_restrict + + +def parse_restrictions( + restrictions: list[str], tune_params: dict, monolithic=False, format=None +) -> list[tuple[Union[Constraint, str], list[str]]]: + """Parses restrictions from a list of strings into compilable functions and constraints, or a single compilable function (if monolithic is True). Returns a list of tuples of (strings or constraints) and parameters.""" + # rewrite the restrictions so variables are singled out + regex_match_variable = r"([a-zA-Z_$][a-zA-Z_$0-9]*)" + + def replace_params(match_object): + key = match_object.group(1) + if key in tune_params and format != "pyatf": + param = str(key) + return "params[params_index['" + param + "']]" + else: + return key + + def replace_params_split(match_object): + # careful: has side-effect of adding to set `params_used` + key = match_object.group(1) + if key in tune_params: + param = str(key) + params_used.add(param) + return param + else: + return key + + # remove functionally duplicate restrictions (preserves order and whitespace) + if all(isinstance(r, str) for r in restrictions): + # clean the restriction strings to functional equivalence + restrictions_cleaned = [r.replace(" ", "") for r in restrictions] + restrictions_cleaned_unique = list(dict.fromkeys(restrictions_cleaned)) # dict preserves order + # get the indices of the unique restrictions, use these to build a new list of restrictions + restrictions_unique_indices = [restrictions_cleaned.index(r) for r in restrictions_cleaned_unique] + restrictions = [restrictions[i] for i in restrictions_unique_indices] + + # create the parsed restrictions + if monolithic is False: + # split into functions that only take their relevant parameters + parsed_restrictions = list() + for res in restrictions: + params_used: set[str] = set() + parsed_restriction = re.sub(regex_match_variable, replace_params_split, res).strip() + params_used = list(params_used) + finalized_constraint = None + # we must turn it into a general function + if format is not None and format.lower() == "pyatf": + finalized_constraint = parsed_restriction + else: + finalized_constraint = f"def r({', '.join(params_used)}): return {parsed_restriction} \n" + parsed_restrictions.append((finalized_constraint, params_used)) + + # if pyATF, restrictions that are set on the same parameter must be combined into one + if format is not None and format.lower() == "pyatf": + res_dict = dict() + registered_params = list() + registered_restrictions = list() + parsed_restrictions_pyatf = list() + for param in tune_params.keys(): + registered_params.append(param) + for index, (res, params) in enumerate(parsed_restrictions): + if index in registered_restrictions: + continue + if all(p in registered_params for p in params): + if param not in res_dict: + res_dict[param] = (list(), list()) + res_dict[param][0].append(res) + res_dict[param][1].extend(params) + registered_restrictions.append(index) + # combine multiple restrictions into one + for res_tuple in res_dict.values(): + res, params_used = res_tuple + params_used = list( + dict.fromkeys(params_used) + ) # param_used should only contain unique, dict preserves order + parsed_restrictions_pyatf.append( + (f"def r({', '.join(params_used)}): return ({') and ('.join(res)}) \n", params_used) + ) + parsed_restrictions = parsed_restrictions_pyatf + else: + # create one monolithic function + parsed_restrictions = ") and (".join( + [re.sub(regex_match_variable, replace_params, res) for res in restrictions] + ) + + # tidy up the code by removing the last suffix and unnecessary spaces + parsed_restrictions = "(" + parsed_restrictions.strip() + ")" + parsed_restrictions = " ".join(parsed_restrictions.split()) + + # provide a mapping of the parameter names to the index in the tuple received + params_index = dict(zip(tune_params.keys(), range(len(tune_params.keys())))) + + if format == "pyatf": + parsed_restrictions = [ + ( + f"def restrictions({', '.join(params_index.keys())}): return {parsed_restrictions} \n", + list(tune_params.keys()), + ) + ] + else: + parsed_restrictions = [ + ( + f"def restrictions(*params): params_index = {params_index}; return {parsed_restrictions} \n", + list(tune_params.keys()), + ) + ] + + return parsed_restrictions + + +def get_all_lambda_asts(func): + """Extracts the AST nodes of all lambda functions defined on the same line as func. + + Args: + func: A lambda function object. + + Returns: + A list of all ast.Lambda node objects on the line where func is defined. + + Raises: + ValueError: If the source can't be retrieved or no lambda is found. + """ + res = [] + try: + source = getsource(func) + source = textwrap.dedent(source).strip() + parsed = ast.parse(source) + + # Find the Lambda node + for node in ast.walk(parsed): + if isinstance(node, ast.Lambda): + res.append(node) + if not res: + raise ValueError(f"No lambda node found in the source {source}.") + except SyntaxError: + """ Ignore syntax errors on the lambda """ + return res + except OSError: + raise ValueError("Could not retrieve source. Is this defined interactively or dynamically?") + return res + + +class InvalidLambdaError(Exception): + def __str__(self): + return "lambda could not be parsed by Kernel Tuner" + + +class ConstraintLambdaTransformer(ast.NodeTransformer): + """Replaces any `NAME['string']` subscript with just `'string'`, if `NAME` + matches the lambda argument name. + """ + def __init__(self, dict_arg_name): + self.dict_arg_name = dict_arg_name + + def visit_Name(self, node): + # If we find a Name node that is not part of a Subscript expression, then + # we throw an exception. This happens when a lambda contains a captured + # variable or calls a function. In these cases, we cannot transform the + # lambda into a string so we just exit the ast transformer. + raise InvalidLambdaError() + + def visit_Subscript(self, node): + # We only replace subscript expressions of the form ['some_string'] + if (isinstance(node.value, ast.Name) + and node.value.id == self.dict_arg_name + and isinstance(node.slice, ast.Constant) + and isinstance(node.slice.value, str)): + # Replace `dict_arg_name['some_key']` with the string used as key + return ast.Name(node.slice.value) + return self.generic_visit(node) + + +def unparse_constraint_lambda(lambda_ast): + """Parse the lambda function to replace accesses to tunable parameter dict + Returns string body of the rewritten lambda function + """ + args = lambda_ast.args + + # Kernel Tuner only allows constraint lambdas with a single argument + if len(args.args) != 1: + raise InvalidLambdaError() + + first_arg = args.args[0].arg + + # Create transformer that replaces accesses to tunable parameter dict + # with simply the name of the tunable parameter + transformer = ConstraintLambdaTransformer(first_arg) + new_lambda_ast = transformer.visit(lambda_ast) + + return ast.unparse(new_lambda_ast.body).strip() + + +def convert_constraint_lambdas(restrictions): + """Extract and convert all constraint lambdas from the restrictions""" + res = [] + for c in restrictions: + if isinstance(c, (str, Constraint)): + res.append(c) + elif callable(c): + try: + lambda_asts = get_all_lambda_asts(c) + res += [unparse_constraint_lambda(lambda_ast) for lambda_ast in lambda_asts] + except (InvalidLambdaError, ValueError): + res.append(c) # it's just a plain function, not a lambda + + + result = list(set(res)) + if not len(result) == len(restrictions): + raise ValueError("An error occured when parsing restrictions. If you mix lambdas and string-based restrictions, please define the lambda first.") + + return result + + +def compile_restrictions( + restrictions: list, tune_params: dict, monolithic=False, format=None +) -> list[tuple[Union[str, FunctionType], list[str], Union[str, None]]]: + """Parses restrictions from a list of strings into a list of strings or Functions and parameters used and source, or a single Function if monolithic is true.""" + restrictions = convert_constraint_lambdas(restrictions) + + # filter the restrictions to get only the strings + restrictions_str, restrictions_ignore = [], [] + for r in restrictions: + (restrictions_str if isinstance(r, str) else restrictions_ignore).append(r) + if len(restrictions_str) == 0: + return restrictions_ignore + + # parse the strings + parsed_restrictions = parse_restrictions(restrictions_str, tune_params, monolithic=monolithic, format=format) + + # compile the parsed restrictions into a function + compiled_restrictions: list[tuple] = list() + for restriction, params_used in parsed_restrictions: + if isinstance(restriction, str): + # if it's a string, parse it to a function + code_object = compile(restriction, "", "exec") + func = FunctionType(code_object.co_consts[0], globals()) + compiled_restrictions.append((func, params_used, restriction)) + elif isinstance(restriction, Constraint): + # otherwise it already is a Constraint, pass it directly + compiled_restrictions.append((restriction, params_used, None)) + else: + raise ValueError(f"Restriction {restriction} is neither a string or Constraint {type(restriction)}") + + # return the restrictions and used parameters + if len(restrictions_ignore) == 0: + return compiled_restrictions + + # use the required parameters or add an empty tuple for unknown parameters of ignored restrictions + noncompiled_restrictions = [] + for r in restrictions_ignore: + if isinstance(r, tuple) and len(r) == 2 and isinstance(r[1], (list, tuple)): + restriction, params_used = r + noncompiled_restrictions.append((restriction, params_used, restriction)) + else: + noncompiled_restrictions.append((r, [], r)) + return noncompiled_restrictions + compiled_restrictions diff --git a/kernel_tuner/util.py b/kernel_tuner/util.py index 386c5ce2..b7f443fa 100644 --- a/kernel_tuner/util.py +++ b/kernel_tuner/util.py @@ -1,36 +1,16 @@ """Module for kernel tuner utility functions.""" -import ast +from inspect import signature +from pathlib import Path import errno import json import logging +import numpy as np import os import re import sys import tempfile -import textwrap import time import warnings -from inspect import getsource, signature -from pathlib import Path -from types import FunctionType -from typing import Union - -import numpy as np -from constraint import ( - AllDifferentConstraint, - AllEqualConstraint, - Constraint, - ExactSumConstraint, - FunctionConstraint, - InSetConstraint, - MaxProdConstraint, - MaxSumConstraint, - MinProdConstraint, - MinSumConstraint, - NotInSetConstraint, - SomeInSetConstraint, - SomeNotInSetConstraint, -) from kernel_tuner.accuracy import Tunable @@ -271,135 +251,6 @@ def check_block_size_params_names_list(block_size_names, tune_params): return block_size_names - -def check_restriction(restrict, params: dict) -> bool: - """Check whether a configuration meets a search space restriction.""" - # if it's a python-constraint, convert to function and execute - if isinstance(restrict, Constraint): - restrict = convert_constraint_restriction(restrict) - return restrict(list(params.values())) - # if it's a string, fill in the parameters and evaluate - elif isinstance(restrict, str): - return eval(replace_param_occurrences(restrict, params)) - # if it's a function, call it - elif callable(restrict): - return restrict(**params) - # if it's a tuple, use only the parameters in the second argument to call the restriction - elif ( - isinstance(restrict, tuple) - and (len(restrict) == 2 or len(restrict) == 3) - and callable(restrict[0]) - and isinstance(restrict[1], (list, tuple)) - ): - # unpack the tuple - if len(restrict) == 2: - restrict, selected_params = restrict - else: - restrict, selected_params, source = restrict - # look up the selected parameters and their value - selected_params = dict((key, params[key]) for key in selected_params) - # call the restriction - if isinstance(restrict, Constraint): - restrict = convert_constraint_restriction(restrict) - return restrict(list(selected_params.values())) - else: - return restrict(**selected_params) - # otherwise, raise an error - else: - raise ValueError(f"Unknown restriction type {type(restrict)} ({restrict})") - - -def check_restrictions(restrictions, params: dict, verbose: bool) -> bool: - """Check whether a configuration meets the search space restrictions.""" - if callable(restrictions): - valid = restrictions(params) - if not valid and verbose is True: - print(f"skipping config {get_instance_string(params)}, reason: config fails restriction") - return valid - valid = True - for restrict in restrictions: - # Check the type of each restriction and validate accordingly. Re-implement as a switch when Python >= 3.10. - try: - valid = check_restriction(restrict, params) - if not valid: - break - except ZeroDivisionError: - logging.debug(f"Restriction {restrict} with configuration {get_instance_string(params)} divides by zero.") - if not valid and verbose is True: - print(f"skipping config {get_instance_string(params)}, reason: config fails restriction {restrict}") - return valid - - -def convert_constraint_restriction(restrict: Constraint): - """Convert the python-constraint to a function for backwards compatibility.""" - if isinstance(restrict, FunctionConstraint): - - def f_restrict(p): - return restrict._func(*p) - - elif isinstance(restrict, AllDifferentConstraint): - - def f_restrict(p): - return len(set(p)) == len(p) - - elif isinstance(restrict, AllEqualConstraint): - - def f_restrict(p): - return all(x == p[0] for x in p) - - elif isinstance(restrict, MaxProdConstraint): - - def f_restrict(p): - return np.prod(p) <= restrict._maxprod - - elif isinstance(restrict, MinProdConstraint): - - def f_restrict(p): - return np.prod(p) >= restrict._minprod - - elif isinstance(restrict, MaxSumConstraint): - - def f_restrict(p): - return sum(p) <= restrict._maxsum - - elif isinstance(restrict, ExactSumConstraint): - - def f_restrict(p): - return sum(p) == restrict._exactsum - - elif isinstance(restrict, MinSumConstraint): - - def f_restrict(p): - return sum(p) >= restrict._minsum - - elif isinstance(restrict, (InSetConstraint, NotInSetConstraint, SomeInSetConstraint, SomeNotInSetConstraint)): - raise NotImplementedError( - f"Restriction of the type {type(restrict)} is explicitly not supported in backwards compatibility mode, because the behaviour is too complex. Please rewrite this constraint to a function to use it with this algorithm." - ) - else: - raise TypeError(f"Unrecognized restriction {restrict}") - return f_restrict - - -def check_thread_block_dimensions(params, max_threads, block_size_names=None): - """Check on maximum thread block dimensions.""" - dims = get_thread_block_dimensions(params, block_size_names) - return np.prod(dims) <= max_threads - - -def config_valid(config, tuning_options, max_threads): - """Combines restrictions and a check on the max thread block dimension to check config validity.""" - legal = True - params = dict(zip(tuning_options.tune_params.keys(), config)) - if tuning_options.restrictions: - legal = check_restrictions(tuning_options.restrictions, params, False) - if not legal: - return False - block_size_names = tuning_options.get("block_size_names", None) - valid_thread_block_dimensions = check_thread_block_dimensions(params, max_threads, block_size_names) - return valid_thread_block_dimensions - - def delete_temp_file(filename): """Delete a temporary file, don't complain if no longer exists.""" try: @@ -899,260 +750,6 @@ def has_kw_argument(func, name): return lambda answer, result_host, atol: v(answer, result_host) -def parse_restrictions( - restrictions: list[str], tune_params: dict, monolithic=False, format=None -) -> list[tuple[Union[Constraint, str], list[str]]]: - """Parses restrictions from a list of strings into compilable functions and constraints, or a single compilable function (if monolithic is True). Returns a list of tuples of (strings or constraints) and parameters.""" - # rewrite the restrictions so variables are singled out - regex_match_variable = r"([a-zA-Z_$][a-zA-Z_$0-9]*)" - - def replace_params(match_object): - key = match_object.group(1) - if key in tune_params and format != "pyatf": - param = str(key) - return "params[params_index['" + param + "']]" - else: - return key - - def replace_params_split(match_object): - # careful: has side-effect of adding to set `params_used` - key = match_object.group(1) - if key in tune_params: - param = str(key) - params_used.add(param) - return param - else: - return key - - # remove functionally duplicate restrictions (preserves order and whitespace) - if all(isinstance(r, str) for r in restrictions): - # clean the restriction strings to functional equivalence - restrictions_cleaned = [r.replace(" ", "") for r in restrictions] - restrictions_cleaned_unique = list(dict.fromkeys(restrictions_cleaned)) # dict preserves order - # get the indices of the unique restrictions, use these to build a new list of restrictions - restrictions_unique_indices = [restrictions_cleaned.index(r) for r in restrictions_cleaned_unique] - restrictions = [restrictions[i] for i in restrictions_unique_indices] - - # create the parsed restrictions - if monolithic is False: - # split into functions that only take their relevant parameters - parsed_restrictions = list() - for res in restrictions: - params_used: set[str] = set() - parsed_restriction = re.sub(regex_match_variable, replace_params_split, res).strip() - params_used = list(params_used) - finalized_constraint = None - # we must turn it into a general function - if format is not None and format.lower() == "pyatf": - finalized_constraint = parsed_restriction - else: - finalized_constraint = f"def r({', '.join(params_used)}): return {parsed_restriction} \n" - parsed_restrictions.append((finalized_constraint, params_used)) - - # if pyATF, restrictions that are set on the same parameter must be combined into one - if format is not None and format.lower() == "pyatf": - res_dict = dict() - registered_params = list() - registered_restrictions = list() - parsed_restrictions_pyatf = list() - for param in tune_params.keys(): - registered_params.append(param) - for index, (res, params) in enumerate(parsed_restrictions): - if index in registered_restrictions: - continue - if all(p in registered_params for p in params): - if param not in res_dict: - res_dict[param] = (list(), list()) - res_dict[param][0].append(res) - res_dict[param][1].extend(params) - registered_restrictions.append(index) - # combine multiple restrictions into one - for res_tuple in res_dict.values(): - res, params_used = res_tuple - params_used = list( - dict.fromkeys(params_used) - ) # param_used should only contain unique, dict preserves order - parsed_restrictions_pyatf.append( - (f"def r({', '.join(params_used)}): return ({') and ('.join(res)}) \n", params_used) - ) - parsed_restrictions = parsed_restrictions_pyatf - else: - # create one monolithic function - parsed_restrictions = ") and (".join( - [re.sub(regex_match_variable, replace_params, res) for res in restrictions] - ) - - # tidy up the code by removing the last suffix and unnecessary spaces - parsed_restrictions = "(" + parsed_restrictions.strip() + ")" - parsed_restrictions = " ".join(parsed_restrictions.split()) - - # provide a mapping of the parameter names to the index in the tuple received - params_index = dict(zip(tune_params.keys(), range(len(tune_params.keys())))) - - if format == "pyatf": - parsed_restrictions = [ - ( - f"def restrictions({', '.join(params_index.keys())}): return {parsed_restrictions} \n", - list(tune_params.keys()), - ) - ] - else: - parsed_restrictions = [ - ( - f"def restrictions(*params): params_index = {params_index}; return {parsed_restrictions} \n", - list(tune_params.keys()), - ) - ] - - return parsed_restrictions - - -def get_all_lambda_asts(func): - """Extracts the AST nodes of all lambda functions defined on the same line as func. - - Args: - func: A lambda function object. - - Returns: - A list of all ast.Lambda node objects on the line where func is defined. - - Raises: - ValueError: If the source can't be retrieved or no lambda is found. - """ - res = [] - try: - source = getsource(func) - source = textwrap.dedent(source).strip() - parsed = ast.parse(source) - - # Find the Lambda node - for node in ast.walk(parsed): - if isinstance(node, ast.Lambda): - res.append(node) - if not res: - raise ValueError(f"No lambda node found in the source {source}.") - except SyntaxError: - """ Ignore syntax errors on the lambda """ - return res - except OSError: - raise ValueError("Could not retrieve source. Is this defined interactively or dynamically?") - return res - - -class InvalidLambdaError(Exception): - def __str__(self): - return "lambda could not be parsed by Kernel Tuner" - - -class ConstraintLambdaTransformer(ast.NodeTransformer): - """Replaces any `NAME['string']` subscript with just `'string'`, if `NAME` - matches the lambda argument name. - """ - def __init__(self, dict_arg_name): - self.dict_arg_name = dict_arg_name - - def visit_Name(self, node): - # If we find a Name node that is not part of a Subscript expression, then - # we throw an exception. This happens when a lambda contains a captured - # variable or calls a function. In these cases, we cannot transform the - # lambda into a string so we just exit the ast transformer. - raise InvalidLambdaError() - - def visit_Subscript(self, node): - # We only replace subscript expressions of the form ['some_string'] - if (isinstance(node.value, ast.Name) - and node.value.id == self.dict_arg_name - and isinstance(node.slice, ast.Constant) - and isinstance(node.slice.value, str)): - # Replace `dict_arg_name['some_key']` with the string used as key - return ast.Name(node.slice.value) - return self.generic_visit(node) - - -def unparse_constraint_lambda(lambda_ast): - """Parse the lambda function to replace accesses to tunable parameter dict - Returns string body of the rewritten lambda function - """ - args = lambda_ast.args - - # Kernel Tuner only allows constraint lambdas with a single argument - if len(args.args) != 1: - raise InvalidLambdaError() - - first_arg = args.args[0].arg - - # Create transformer that replaces accesses to tunable parameter dict - # with simply the name of the tunable parameter - transformer = ConstraintLambdaTransformer(first_arg) - new_lambda_ast = transformer.visit(lambda_ast) - - return ast.unparse(new_lambda_ast.body).strip() - - -def convert_constraint_lambdas(restrictions): - """Extract and convert all constraint lambdas from the restrictions""" - res = [] - for c in restrictions: - if isinstance(c, (str, Constraint)): - res.append(c) - elif callable(c): - try: - lambda_asts = get_all_lambda_asts(c) - res += [unparse_constraint_lambda(lambda_ast) for lambda_ast in lambda_asts] - except (InvalidLambdaError, ValueError): - res.append(c) # it's just a plain function, not a lambda - - - result = list(set(res)) - if not len(result) == len(restrictions): - raise ValueError("An error occured when parsing restrictions. If you mix lambdas and string-based restrictions, please define the lambda first.") - - return result - - -def compile_restrictions( - restrictions: list, tune_params: dict, monolithic=False, format=None -) -> list[tuple[Union[str, FunctionType], list[str], Union[str, None]]]: - """Parses restrictions from a list of strings into a list of strings or Functions and parameters used and source, or a single Function if monolithic is true.""" - restrictions = convert_constraint_lambdas(restrictions) - - # filter the restrictions to get only the strings - restrictions_str, restrictions_ignore = [], [] - for r in restrictions: - (restrictions_str if isinstance(r, str) else restrictions_ignore).append(r) - if len(restrictions_str) == 0: - return restrictions_ignore - - # parse the strings - parsed_restrictions = parse_restrictions(restrictions_str, tune_params, monolithic=monolithic, format=format) - - # compile the parsed restrictions into a function - compiled_restrictions: list[tuple] = list() - for restriction, params_used in parsed_restrictions: - if isinstance(restriction, str): - # if it's a string, parse it to a function - code_object = compile(restriction, "", "exec") - func = FunctionType(code_object.co_consts[0], globals()) - compiled_restrictions.append((func, params_used, restriction)) - elif isinstance(restriction, Constraint): - # otherwise it already is a Constraint, pass it directly - compiled_restrictions.append((restriction, params_used, None)) - else: - raise ValueError(f"Restriction {restriction} is neither a string or Constraint {type(restriction)}") - - # return the restrictions and used parameters - if len(restrictions_ignore) == 0: - return compiled_restrictions - - # use the required parameters or add an empty tuple for unknown parameters of ignored restrictions - noncompiled_restrictions = [] - for r in restrictions_ignore: - if isinstance(r, tuple) and len(r) == 2 and isinstance(r[1], (list, tuple)): - restriction, params_used = r - noncompiled_restrictions.append((restriction, params_used, restriction)) - else: - noncompiled_restrictions.append((r, [], r)) - return noncompiled_restrictions + compiled_restrictions def check_matching_problem_size(cached_problem_size, problem_size): """Check the if requested problem size matches the problem size in the cache.""" @@ -1318,7 +915,6 @@ def store_cache(key, params, tuning_options): with open(tuning_options.cachefile, "a") as cachefile: cachefile.write("\n" + json.dumps({key: output_params}, cls=NpEncoder)[1:-1] + ",") - def dump_cache(obj: str, tuning_options): """Dumps a string in the cache, this omits the several checks of store_cache() to speed up the process - with great power comes great responsibility!""" if isinstance(tuning_options.cache, dict) and tuning_options.cachefile: From 1198d7bb66499c8fed257b00176c56c66ac0b891 Mon Sep 17 00:00:00 2001 From: stijn Date: Wed, 14 Jan 2026 16:57:55 +0100 Subject: [PATCH 3/3] Clean up of some code in `restrictions.py` --- kernel_tuner/restrictions.py | 130 ++++++++++++++++++- kernel_tuner/searchspace.py | 234 ++++++++--------------------------- 2 files changed, 181 insertions(+), 183 deletions(-) diff --git a/kernel_tuner/restrictions.py b/kernel_tuner/restrictions.py index e74fb944..9c813656 100644 --- a/kernel_tuner/restrictions.py +++ b/kernel_tuner/restrictions.py @@ -41,7 +41,7 @@ def check_restriction(restrict, params: dict) -> bool: # if it's a tuple, use only the parameters in the second argument to call the restriction elif ( isinstance(restrict, tuple) - and (len(restrict) == 2 or len(restrict) == 3) + and len(restrict) in (2, 3) and callable(restrict[0]) and isinstance(restrict[1], (list, tuple)) ): @@ -170,7 +170,7 @@ def replace_params_split(match_object): restrictions = [restrictions[i] for i in restrictions_unique_indices] # create the parsed restrictions - if monolithic is False: + if not monolithic: # split into functions that only take their relevant parameters parsed_restrictions = list() for res in restrictions: @@ -355,7 +355,11 @@ def compile_restrictions( # filter the restrictions to get only the strings restrictions_str, restrictions_ignore = [], [] for r in restrictions: - (restrictions_str if isinstance(r, str) else restrictions_ignore).append(r) + if isinstance(r, str): + restrictions_str.append(r) + else: + restrictions_ignore.append(r) + if len(restrictions_str) == 0: return restrictions_ignore @@ -389,3 +393,123 @@ def compile_restrictions( else: noncompiled_restrictions.append((r, [], r)) return noncompiled_restrictions + compiled_restrictions + + +def parse_restrictions_pysmt(restrictions: list, tune_params: dict, symbols: dict): + """Parses restrictions from a list of strings into PySMT compatible restrictions.""" + from pysmt.shortcuts import ( + GE, + GT, + LE, + LT, + And, + Bool, + Div, + Equals, + Int, + Minus, + Or, + Plus, + Pow, + Real, + String, + Times, + ) + + regex_match_variable = r"([a-zA-Z_$][a-zA-Z_$0-9]*)" + + boolean_comparison_mapping = { + "==": Equals, + "<": LT, + "<=": LE, + ">=": GE, + ">": GT, + "&&": And, + "||": Or, + } + + operators_mapping = {"+": Plus, "-": Minus, "*": Times, "/": Div, "^": Pow} + + constant_init_mapping = { + "int": Int, + "float": Real, + "str": String, + "bool": Bool, + } + + def replace_params(match_object): + key = match_object.group(1) + if key in tune_params: + return 'params["' + key + '"]' + else: + return key + + # rewrite the restrictions so variables are singled out + parsed = [re.sub(regex_match_variable, replace_params, res) for res in restrictions] + # ensure no duplicates are in the list + parsed = list(set(parsed)) + # replace ' or ' and ' and ' with ' || ' and ' && ' + parsed = list(r.replace(" or ", " || ").replace(" and ", " && ") for r in parsed) + + # compile each restriction by replacing parameters and operators with their PySMT equivalent + compiled_restrictions = list() + for parsed_restriction in parsed: + words = parsed_restriction.split(" ") + + # make a forward pass over all the words to organize and substitute + add_next_var_or_constant = False + var_or_constant_backlog = list() + operator_backlog = list() + operator_backlog_left_right = list() + boolean_backlog = list() + for word in words: + if word.startswith("params["): + # if variable + varname = word.replace('params["', "").replace('"]', "") + var = symbols[varname] + var_or_constant_backlog.append(var) + elif word in boolean_comparison_mapping: + # if comparator + boolean_backlog.append(boolean_comparison_mapping[word]) + continue + elif word in operators_mapping: + # if operator + operator_backlog.append(operators_mapping[word]) + add_next_var_or_constant = True + continue + else: + # if constant: evaluate to check if it is an integer, float, etc. If not, treat it as a string. + try: + constant = ast.literal_eval(word) + except ValueError: + constant = word + # convert from Python type to PySMT equivalent + type_instance = constant_init_mapping[type(constant).__name__] + var_or_constant_backlog.append(type_instance(constant)) + if add_next_var_or_constant: + right, left = var_or_constant_backlog.pop(-1), var_or_constant_backlog.pop(-1) + operator_backlog_left_right.append((left, right, len(var_or_constant_backlog))) + add_next_var_or_constant = False + # reserve an empty spot for the combined operation to preserve the order + var_or_constant_backlog.append(None) + + # for each of the operators, instantiate them with variables or constants + for i, operator in enumerate(operator_backlog): + # merges the first two symbols in the backlog into one + left, right, new_index = operator_backlog_left_right[i] + assert ( + var_or_constant_backlog[new_index] is None + ) # make sure that this is a reserved spot to avoid changing the order + var_or_constant_backlog[new_index] = operator(left, right) + + # for each of the booleans, instantiate them with variables or constants + compiled = list() + assert len(boolean_backlog) <= 1, "Max. one boolean operator per restriction." + for boolean in boolean_backlog: + left, right = var_or_constant_backlog.pop(0), var_or_constant_backlog.pop(0) + compiled.append(boolean(left, right)) + + # add the restriction to the list of restrictions + compiled_restrictions.append(compiled[0]) + + return And(compiled_restrictions) diff --git a/kernel_tuner/searchspace.py b/kernel_tuner/searchspace.py index d3d00052..31a48d52 100644 --- a/kernel_tuner/searchspace.py +++ b/kernel_tuner/searchspace.py @@ -31,10 +31,13 @@ except ImportError: torch_available = False -from kernel_tuner.util import check_restrictions as check_instance_restrictions -from kernel_tuner.util import ( +from kernel_tuner.restrictions import check_restrictions as check_instance_restrictions +from kernel_tuner.restrictions import ( compile_restrictions, convert_constraint_lambdas, + parse_restrictions_pysmt, +) +from kernel_tuner.util import ( default_block_size_names, get_interval, ) @@ -78,9 +81,16 @@ def __init__( if from_cache is None: assert tune_params is not None and max_threads is not None, "Must specify positional arguments." + # Normalize `restrictions` to a list of items + if restrictions is None: + restrictions = [] + elif isinstance(restrictions, (tuple, list)): + restrictions = list(restrictions) # Create copy + else: + restrictions = [restrictions] + # set the object attributes using the arguments framework_l = framework.lower() - restrictions = restrictions if restrictions is not None else [] self.tune_params = tune_params self.original_tune_params = tune_params.copy() if hasattr(tune_params, "copy") else tune_params self.max_threads = max_threads @@ -95,7 +105,6 @@ def __init__( self._tensorspace_param_config_structure = [] self._map_tensor_to_param = {} self._map_param_to_tensor = {} - restrictions = [restrictions] if not isinstance(restrictions, (list, tuple)) else restrictions self.restrictions = deepcopy(restrictions) self.original_restrictions = deepcopy(restrictions) # keep the original restrictions, so that the searchspace can be modified later # the searchspace can add commonly used constraints (e.g. maxprod(blocks) <= maxthreads) @@ -123,18 +132,15 @@ def __init__( raise ValueError(f"Neighbor method is {neighbor_method}, must be one of {supported_neighbor_methods}") # if there are strings in the restrictions, parse them to split constraints or functions (improves solver performance) - restrictions = [restrictions] if not isinstance(restrictions, list) else restrictions if ( - len(restrictions) > 0 - and ( + ( any(isinstance(restriction, str) for restriction in restrictions) or any( isinstance(restriction[0], str) for restriction in restrictions if isinstance(restriction, tuple) ) ) - and not ( - framework_l == "pysmt" or framework_l == "bruteforce" or framework_l == "pythonconstraint" or solver_method.lower() == "pc_parallelsolver" - ) + and framework_l not in ("pysmt", "bruteforce" "pythonconstraint") + and solver_method.lower() != "pc_parallelsolver" ): self.restrictions = compile_restrictions( restrictions, @@ -166,7 +172,6 @@ def __init__( raise ValueError(f"Invalid framework parameter {framework}") # get the solver given the solver method argument - solver = "" if solver_method.lower() == "pc_backtrackingsolver": solver = BacktrackingSolver() elif solver_method.lower() == "pc_optimizedbacktrackingsolver": @@ -231,10 +236,8 @@ def __init__( def __build_searchspace_bruteforce(self, block_size_names: list, max_threads: int, solver=None): # bruteforce solving of the searchspace - from itertools import product - - from kernel_tuner.util import check_restrictions + from kernel_tuner.restrictions import check_restrictions tune_params = self.tune_params restrictions = self.restrictions @@ -246,11 +249,11 @@ def __build_searchspace_bruteforce(self, block_size_names: list, max_threads: in used_block_size_names = list( block_size_name for block_size_name in default_block_size_names if block_size_name in tune_params ) + if len(used_block_size_names) > 0: - if not isinstance(restrictions, list): - restrictions = [restrictions] block_size_restriction_spaced = f"{' * '.join(used_block_size_names)} <= {max_threads}" block_size_restriction_unspaced = f"{'*'.join(used_block_size_names)} <= {max_threads}" + if ( block_size_restriction_spaced not in restrictions and block_size_restriction_unspaced not in restrictions @@ -266,7 +269,7 @@ def __build_searchspace_bruteforce(self, block_size_names: list, max_threads: in self.restrictions.append(block_size_restriction_spaced) # check for search space restrictions - if restrictions is not None: + if len(restrictions) > 0: parameter_space = filter( lambda p: check_restrictions(restrictions, dict(zip(tune_params.keys(), p)), False), parameter_space ) @@ -313,7 +316,7 @@ def all_smt(formula, keys) -> list: domains = And(domains) # add the restrictions - problem = self.__parse_restrictions_pysmt(restrictions, tune_params, symbols) + problem = parse_restrictions_pysmt(restrictions, tune_params, symbols) # combine the domain and restrictions formula = And(domains, problem) @@ -409,7 +412,7 @@ def __build_searchspace_pyATF(self, block_size_names: list, max_threads: int, so # Define a bogus cost function costfunc = CostFunction(":") # bash no-op - + # set data self.tune_params_pyatf = self.get_tune_params_pyatf(block_size_names, max_threads) @@ -423,16 +426,17 @@ def __build_searchspace_pyATF(self, block_size_names: list, max_threads: int, so parameter_tuple_list = list() for entry in tuning_data.history._entries: parameter_tuple_list.append(tuple(entry.configuration[p] for p in tune_params.keys())) - pl = self.__parameter_space_list_to_lookup_and_return_type(parameter_tuple_list) - return pl + + return self.__parameter_space_list_to_lookup_and_return_type(parameter_tuple_list) def __build_searchspace_ATF_cache(self, block_size_names: list, max_threads: int, solver: Solver): + import pandas as pd + """Imports the valid configurations from an ATF CSV file, returns the searchspace, a dict of the searchspace for fast lookups and the size.""" if block_size_names != default_block_size_names or max_threads != 1024: raise ValueError( "It is not possible to change 'block_size_names' or 'max_threads here, because at this point ATF has already ran.'" ) - import pandas as pd try: df = pd.read_csv(self.path_to_ATF_cache, sep=";") @@ -459,7 +463,7 @@ def __parameter_space_list_to_lookup_and_return_type( parameter_space_dict, size_list, ) - + def __build_searchspace(self, block_size_names: list, max_threads: int, solver: Solver): """Compute valid configurations in a search space based on restrictions and max_threads.""" # instantiate the parameter space with all the variables @@ -468,8 +472,6 @@ def __build_searchspace(self, block_size_names: list, max_threads: int, solver: parameter_space.addVariable(str(param_name), param_values) # add the user-specified restrictions as constraints on the parameter space - if not isinstance(self.restrictions, (list, tuple)): - self.restrictions = [self.restrictions] if any(not isinstance(restriction, (Constraint, FunctionConstraint, str)) for restriction in self.restrictions): self.restrictions = convert_constraint_lambdas(self.restrictions) parameter_space = self.__add_restrictions(parameter_space) @@ -486,8 +488,7 @@ def __build_searchspace(self, block_size_names: list, max_threads: int, solver: and max_block_size_product not in self._modified_restrictions ): self._modified_restrictions.append(max_block_size_product) - if isinstance(self.restrictions, list): - self.restrictions.append((MaxProdConstraint(max_threads), valid_block_size_names, None)) + self.restrictions.append((MaxProdConstraint(max_threads), valid_block_size_names, None)) # construct the parameter space with the constraints applied return parameter_space.getSolutionsAsListDict(order=self.param_names) @@ -495,165 +496,38 @@ def __build_searchspace(self, block_size_names: list, max_threads: int, solver: def __add_restrictions(self, parameter_space: Problem) -> Problem: """Add the user-specified restrictions as constraints on the parameter space.""" restrictions = deepcopy(self.restrictions) - if isinstance(restrictions, list): - for restriction in restrictions: - required_params = self.param_names - - # (un)wrap where necessary - if isinstance(restriction, tuple) and len(restriction) >= 2: - required_params = restriction[1] - restriction = restriction[0] - if callable(restriction) and not isinstance(restriction, Constraint): - # def restrictions_wrapper(*args): - # return check_instance_restrictions(restriction, dict(zip(self.param_names, args)), False) - # print(restriction, isinstance(restriction, Constraint)) - # restriction = FunctionConstraint(restrictions_wrapper) - restriction = FunctionConstraint(restriction, required_params) - - # add as a Constraint - all_params_required = all(param_name in required_params for param_name in self.param_names) - variables = None if all_params_required else required_params - if isinstance(restriction, FunctionConstraint): - parameter_space.addConstraint(restriction, variables) - elif isinstance(restriction, Constraint): - parameter_space.addConstraint(restriction, variables) - elif isinstance(restriction, str): - if self.solver_method.lower() == "pc_parallelsolver": - parameter_space.addConstraint(restriction) - else: - parameter_space.addConstraint(restriction, variables) - else: - raise ValueError(f"Unrecognized restriction type {type(restriction)} ({restriction})") - # if the restrictions are the old monolithic function, apply them directly (only for backwards compatibility, likely slower than well-specified constraints!) - elif callable(restrictions): + for restriction in restrictions: + required_params = self.param_names - def restrictions_wrapper(*args): - return check_instance_restrictions(restrictions, dict(zip(self.param_names, args)), False) + # (un)wrap where necessary + if isinstance(restriction, tuple) and len(restriction) >= 2: + required_params = restriction[1] + restriction = restriction[0] - parameter_space.addConstraint(FunctionConstraint(restrictions_wrapper), self.param_names) - elif restrictions is not None: - raise ValueError(f"The restrictions are of unsupported type {type(restrictions)}") - return parameter_space + if callable(restriction) and not isinstance(restriction, Constraint): + def restrictions_wrapper(*args): + return check_instance_restrictions(restriction, dict(zip(self.param_names, args))) - def __parse_restrictions_pysmt(self, restrictions: list, tune_params: dict, symbols: dict): - """Parses restrictions from a list of strings into PySMT compatible restrictions.""" - from pysmt.shortcuts import ( - GE, - GT, - LE, - LT, - And, - Bool, - Div, - Equals, - Int, - Minus, - Or, - Plus, - Pow, - Real, - String, - Times, - ) + restriction = FunctionConstraint(restrictions_wrapper) - regex_match_variable = r"([a-zA-Z_$][a-zA-Z_$0-9]*)" - - boolean_comparison_mapping = { - "==": Equals, - "<": LT, - "<=": LE, - ">=": GE, - ">": GT, - "&&": And, - "||": Or, - } - - operators_mapping = {"+": Plus, "-": Minus, "*": Times, "/": Div, "^": Pow} - - constant_init_mapping = { - "int": Int, - "float": Real, - "str": String, - "bool": Bool, - } - - def replace_params(match_object): - key = match_object.group(1) - if key in tune_params: - return 'params["' + key + '"]' - else: - return key - - # rewrite the restrictions so variables are singled out - parsed = [re.sub(regex_match_variable, replace_params, res) for res in restrictions] - # ensure no duplicates are in the list - parsed = list(set(parsed)) - # replace ' or ' and ' and ' with ' || ' and ' && ' - parsed = list(r.replace(" or ", " || ").replace(" and ", " && ") for r in parsed) - - # compile each restriction by replacing parameters and operators with their PySMT equivalent - compiled_restrictions = list() - for parsed_restriction in parsed: - words = parsed_restriction.split(" ") - - # make a forward pass over all the words to organize and substitute - add_next_var_or_constant = False - var_or_constant_backlog = list() - operator_backlog = list() - operator_backlog_left_right = list() - boolean_backlog = list() - for word in words: - if word.startswith("params["): - # if variable - varname = word.replace('params["', "").replace('"]', "") - var = symbols[varname] - var_or_constant_backlog.append(var) - elif word in boolean_comparison_mapping: - # if comparator - boolean_backlog.append(boolean_comparison_mapping[word]) - continue - elif word in operators_mapping: - # if operator - operator_backlog.append(operators_mapping[word]) - add_next_var_or_constant = True - continue + # add as a Constraint + all_params_required = all(param_name in required_params for param_name in self.param_names) + variables = None if all_params_required else required_params + + if isinstance(restriction, Constraint): + parameter_space.addConstraint(restriction, variables) + elif isinstance(restriction, str): + if self.solver_method.lower() == "pc_parallelsolver": + parameter_space.addConstraint(restriction) else: - # if constant: evaluate to check if it is an integer, float, etc. If not, treat it as a string. - try: - constant = ast.literal_eval(word) - except ValueError: - constant = word - # convert from Python type to PySMT equivalent - type_instance = constant_init_mapping[type(constant).__name__] - var_or_constant_backlog.append(type_instance(constant)) - if add_next_var_or_constant: - right, left = var_or_constant_backlog.pop(-1), var_or_constant_backlog.pop(-1) - operator_backlog_left_right.append((left, right, len(var_or_constant_backlog))) - add_next_var_or_constant = False - # reserve an empty spot for the combined operation to preserve the order - var_or_constant_backlog.append(None) - - # for each of the operators, instantiate them with variables or constants - for i, operator in enumerate(operator_backlog): - # merges the first two symbols in the backlog into one - left, right, new_index = operator_backlog_left_right[i] - assert ( - var_or_constant_backlog[new_index] is None - ) # make sure that this is a reserved spot to avoid changing the order - var_or_constant_backlog[new_index] = operator(left, right) - - # for each of the booleans, instantiate them with variables or constants - compiled = list() - assert len(boolean_backlog) <= 1, "Max. one boolean operator per restriction." - for boolean in boolean_backlog: - left, right = var_or_constant_backlog.pop(0), var_or_constant_backlog.pop(0) - compiled.append(boolean(left, right)) - - # add the restriction to the list of restrictions - compiled_restrictions.append(compiled[0]) - - return And(compiled_restrictions) + parameter_space.addConstraint(restriction, variables) + else: + raise ValueError(f"Unrecognized restriction type {type(restriction)} ({restriction})") + + + return parameter_space + def sorted_list(self, sort_last_param_first=False): """Returns list of parameter configs sorted based on the order in which the parameter values were specified.