diff --git a/.gitignore b/.gitignore index 479c7188..c4b5180a 100644 --- a/.gitignore +++ b/.gitignore @@ -21,6 +21,7 @@ push_to_pypi.sh !kernel_tuner/schema/T1/1.0.0/input-schema.json !test/test_T1_input.json !test_cache_file*.json +!test/regression/baselines/*.json *.csv .cache *.ipynb_checkpoints diff --git a/kernel_tuner/backends/nvcuda.py b/kernel_tuner/backends/nvcuda.py index 15259cb2..86254606 100644 --- a/kernel_tuner/backends/nvcuda.py +++ b/kernel_tuner/backends/nvcuda.py @@ -9,10 +9,20 @@ # embedded in try block to be able to generate documentation # and run tests without cuda-python installed +# Support both cuda-python < 13 and >= 13 import structures try: - from cuda import cuda, cudart, nvrtc + # cuda-python >= 13 uses cuda.bindings module + from cuda.bindings import driver as cuda + from cuda.bindings import runtime as cudart + from cuda.bindings import nvrtc except ImportError: - cuda = None + try: + # cuda-python < 13 uses direct imports + from cuda import cuda, cudart, nvrtc + except ImportError: + cuda = None + cudart = None + nvrtc = None class CudaFunctions(GPUBackend): diff --git a/kernel_tuner/interface.py b/kernel_tuner/interface.py index 32e91c86..c5575972 100644 --- a/kernel_tuner/interface.py +++ b/kernel_tuner/interface.py @@ -799,7 +799,7 @@ def run_kernel( try: # create kernel instance instance = dev.create_kernel_instance(kernelsource, kernel_options, params, False) - if instance is None: + if isinstance(instance, util.InvalidConfig): raise RuntimeError("cannot create kernel instance, too many threads per block") # see if the kernel arguments have correct type @@ -821,7 +821,7 @@ def run_kernel( dev.copy_texture_memory_args(texmem_args) finally: # delete temp files - if instance is not None: + if instance is not None and not isinstance(instance, util.ErrorConfig): instance.delete_temp_files() # run the kernel diff --git a/kernel_tuner/observers/nvcuda.py b/kernel_tuner/observers/nvcuda.py index c0a33ad5..af428973 100644 --- 
a/kernel_tuner/observers/nvcuda.py +++ b/kernel_tuner/observers/nvcuda.py @@ -1,9 +1,15 @@ import numpy as np +# Support both cuda-python < 13 and >= 13 import structures try: - from cuda import cudart + # cuda-python >= 13 uses cuda.bindings module + from cuda.bindings import runtime as cudart except ImportError: - cuda = None + try: + # cuda-python < 13 uses direct imports + from cuda import cudart + except ImportError: + cudart = None from kernel_tuner.observers.observer import BenchmarkObserver from kernel_tuner.util import cuda_error_check diff --git a/kernel_tuner/searchspace.py b/kernel_tuner/searchspace.py index d3d00052..16efac37 100644 --- a/kernel_tuner/searchspace.py +++ b/kernel_tuner/searchspace.py @@ -504,11 +504,18 @@ def __add_restrictions(self, parameter_space: Problem) -> Problem: required_params = restriction[1] restriction = restriction[0] if callable(restriction) and not isinstance(restriction, Constraint): - # def restrictions_wrapper(*args): - # return check_instance_restrictions(restriction, dict(zip(self.param_names, args)), False) - # print(restriction, isinstance(restriction, Constraint)) - # restriction = FunctionConstraint(restrictions_wrapper) - restriction = FunctionConstraint(restriction, required_params) + # Wrap the restriction to convert positional args to keyword args for backwards compatibility + # Old API: restriction received keyword args (via **params unpacking) + # New API: FunctionConstraint passes positional args + original_restriction = restriction + params_for_wrapper = required_params + + def make_wrapper(func, param_names): + def restrictions_wrapper(*args): + return func(**dict(zip(param_names, args))) + return restrictions_wrapper + + restriction = FunctionConstraint(make_wrapper(original_restriction, params_for_wrapper), required_params) # add as a Constraint all_params_required = all(param_name in required_params for param_name in self.param_names) @@ -1421,13 +1428,19 @@ def get_random_neighbor(self, 
def has_closure_variables(func):
    """Check if a callable has captured closure variables.

    Args:
        func: Any callable. Plain functions and lambdas expose ``__closure__``;
            other callables (builtins, ``functools.partial`` objects, instances
            defining ``__call__``) do not have that attribute at all.

    Returns:
        True if ``func`` has a non-empty ``__closure__`` tuple, False otherwise
        (including when the attribute is missing or is None).
    """
    # getattr with a default avoids an AttributeError for callables that are
    # not plain Python functions; an absent, None or empty closure is falsy.
    return bool(getattr(func, "__closure__", None))
+ + Lambdas with captured closure variables are kept as-is to preserve + the closure context. Only simple lambdas without closures are converted + to strings for the constraint solver. + """ res = [] for c in restrictions: if isinstance(c, (str, Constraint)): res.append(c) if callable(c) and not isinstance(c, Constraint): + # If the lambda has closure variables, keep it as a callable + # to preserve the captured variable context + if has_closure_variables(c): + res.append(c) + continue + try: lambda_asts = get_all_lambda_asts(c) except ValueError: diff --git a/test/regression/__init__.py b/test/regression/__init__.py new file mode 100644 index 00000000..fea69674 --- /dev/null +++ b/test/regression/__init__.py @@ -0,0 +1 @@ +# Regression tests for Kernel Tuner diff --git a/test/regression/baselines/vector_add_NVIDIA_RTX_A4000.json b/test/regression/baselines/vector_add_NVIDIA_RTX_A4000.json new file mode 100644 index 00000000..3299441c --- /dev/null +++ b/test/regression/baselines/vector_add_NVIDIA_RTX_A4000.json @@ -0,0 +1,313 @@ +{ + "device_name": "NVIDIA RTX A4000", + "kernel_name": "vector_add", + "tune_params_keys": [ + "block_size_x" + ], + "tune_params": { + "block_size_x": [ + 128, + 192, + 256, + 320, + 384, + 448, + 512, + 576, + 640, + 704, + 768, + 832, + 896, + 960, + 1024 + ] + }, + "cache": { + "128": { + "block_size_x": 128, + "time": 0.04073600071881499, + "times": [ + 0.1268800050020218, + 0.031072000041604042, + 0.027295999228954315, + 0.025472000241279602, + 0.025119999423623085, + 0.025248000398278236, + 0.024064000695943832 + ], + "compile_time": 440.9545585513115, + "verification_time": 0, + "benchmark_time": 1.091592013835907, + "strategy_time": 0, + "framework_time": 0.8587837219238281, + "timestamp": "2022-12-23 12:11:26.411558+00:00" + }, + "192": { + "block_size_x": 192, + "time": 0.04095085710287094, + "times": [ + 0.12908799946308136, + 0.03046399913728237, + 0.027744000777602196, + 0.025151999667286873, + 0.024960000067949295, + 
0.024992000311613083, + 0.02425600029528141 + ], + "compile_time": 436.15153804421425, + "verification_time": 0, + "benchmark_time": 1.0972395539283752, + "strategy_time": 0, + "framework_time": 1.6656816005706787, + "timestamp": "2022-12-23 12:11:26.411558+00:00" + }, + "256": { + "block_size_x": 256, + "time": 0.04189257137477398, + "times": [ + 0.13180799782276154, + 0.031136000528931618, + 0.028095999732613564, + 0.027008000761270523, + 0.025087999179959297, + 0.02505600079894066, + 0.02505600079894066 + ], + "compile_time": 436.5839697420597, + "verification_time": 0, + "benchmark_time": 1.0691732168197632, + "strategy_time": 0, + "framework_time": 1.6054585576057434, + "timestamp": "2022-12-23 12:11:26.411558+00:00" + }, + "320": { + "block_size_x": 320, + "time": 0.04208914376795292, + "times": [ + 0.1358720064163208, + 0.030688000842928886, + 0.02768000029027462, + 0.02582399919629097, + 0.025087999179959297, + 0.025312000885605812, + 0.024159999564290047 + ], + "compile_time": 438.9761835336685, + "verification_time": 0, + "benchmark_time": 1.0976120829582214, + "strategy_time": 0, + "framework_time": 1.4494173228740692, + "timestamp": "2022-12-23 12:11:26.411558+00:00" + }, + "384": { + "block_size_x": 384, + "time": 0.04174171467976911, + "times": [ + 0.13251200318336487, + 0.03167999908328056, + 0.027871999889612198, + 0.025312000885605812, + 0.024671999737620354, + 0.02505600079894066, + 0.025087999179959297 + ], + "compile_time": 440.71199372410774, + "verification_time": 0, + "benchmark_time": 1.0499358177185059, + "strategy_time": 0, + "framework_time": 1.682564616203308, + "timestamp": "2022-12-23 12:11:26.411558+00:00" + }, + "448": { + "block_size_x": 448, + "time": 0.03249828570655414, + "times": [ + 0.0647680014371872, + 0.03167999908328056, + 0.028255999088287354, + 0.025280000641942024, + 0.027103999629616737, + 0.02550400048494339, + 0.02489599958062172 + ], + "compile_time": 449.13655519485474, + "verification_time": 0, + "benchmark_time": 
1.1196956038475037, + "strategy_time": 0, + "framework_time": 1.5890561044216156, + "timestamp": "2022-12-23 12:11:26.411558+00:00" + }, + "512": { + "block_size_x": 512, + "time": 0.04139885599059718, + "times": [ + 0.13023999333381653, + 0.031136000528931618, + 0.02831999957561493, + 0.02595200017094612, + 0.024607999250292778, + 0.025151999667286873, + 0.024383999407291412 + ], + "compile_time": 440.5844733119011, + "verification_time": 0, + "benchmark_time": 1.09076127409935, + "strategy_time": 0, + "framework_time": 1.853298395872116, + "timestamp": "2022-12-23 12:11:26.411558+00:00" + }, + "576": { + "block_size_x": 576, + "time": 0.04189257137477398, + "times": [ + 0.12995199859142303, + 0.03200000151991844, + 0.028511999174952507, + 0.026623999699950218, + 0.025760000571608543, + 0.02537599951028824, + 0.02502400055527687 + ], + "compile_time": 442.16764718294144, + "verification_time": 0, + "benchmark_time": 1.1038780212402344, + "strategy_time": 0, + "framework_time": 1.8403716385364532, + "timestamp": "2022-12-23 12:11:26.411558+00:00" + }, + "640": { + "block_size_x": 640, + "time": 0.0411702852163996, + "times": [ + 0.12796799838542938, + 0.03081599995493889, + 0.02969600073993206, + 0.025439999997615814, + 0.02409599907696247, + 0.02582399919629097, + 0.024351999163627625 + ], + "compile_time": 437.98910081386566, + "verification_time": 0, + "benchmark_time": 1.0496266186237335, + "strategy_time": 0, + "framework_time": 1.8264725804328918, + "timestamp": "2022-12-23 12:11:26.411558+00:00" + }, + "704": { + "block_size_x": 704, + "time": 0.04184228580977235, + "times": [ + 0.1343040019273758, + 0.03094400092959404, + 0.02908799983561039, + 0.025151999667286873, + 0.02486399933695793, + 0.024447999894618988, + 0.02409599907696247 + ], + "compile_time": 443.51235404610634, + "verification_time": 0, + "benchmark_time": 1.1033527553081512, + "strategy_time": 0, + "framework_time": 1.6709677875041962, + "timestamp": "2022-12-23 12:11:26.411558+00:00" + }, + 
"768": { + "block_size_x": 768, + "time": 0.03175771422684193, + "times": [ + 0.06230400130152702, + 0.0315839983522892, + 0.02831999957561493, + 0.02672000043094158, + 0.023679999634623528, + 0.023903999477624893, + 0.02579200081527233 + ], + "compile_time": 450.4409395158291, + "verification_time": 0, + "benchmark_time": 1.101326197385788, + "strategy_time": 0, + "framework_time": 1.7531625926494598, + "timestamp": "2022-12-23 12:11:26.411558+00:00" + }, + "832": { + "block_size_x": 832, + "time": 0.040941715240478516, + "times": [ + 0.12998400628566742, + 0.03094400092959404, + 0.027103999629616737, + 0.024768000468611717, + 0.025439999997615814, + 0.023903999477624893, + 0.024447999894618988 + ], + "compile_time": 439.9200603365898, + "verification_time": 0, + "benchmark_time": 1.0421127080917358, + "strategy_time": 0, + "framework_time": 2.1368376910686493, + "timestamp": "2022-12-23 12:11:26.411558+00:00" + }, + "896": { + "block_size_x": 896, + "time": 0.04177371359297207, + "times": [ + 0.12931199371814728, + 0.03731200098991394, + 0.02812799997627735, + 0.02502400055527687, + 0.02412799932062626, + 0.024768000468611717, + 0.023744000121951103 + ], + "compile_time": 439.23527002334595, + "verification_time": 0, + "benchmark_time": 1.0946877300739288, + "strategy_time": 0, + "framework_time": 2.03637033700943, + "timestamp": "2022-12-23 12:11:26.411558+00:00" + }, + "960": { + "block_size_x": 960, + "time": 0.042189714631864, + "times": [ + 0.1335040032863617, + 0.031039999797940254, + 0.02876799926161766, + 0.02579200081527233, + 0.025119999423623085, + 0.02566399984061718, + 0.025439999997615814 + ], + "compile_time": 441.7596235871315, + "verification_time": 0, + "benchmark_time": 1.1166557669639587, + "strategy_time": 0, + "framework_time": 1.7383433878421783, + "timestamp": "2022-12-23 12:11:26.411558+00:00" + }, + "1024": { + "block_size_x": 1024, + "time": 0.04114742816558906, + "times": [ + 0.13087999820709229, + 0.03049599938094616, + 
def load_baseline(filename: str) -> dict:
    """Load a baseline JSON file.

    Args:
        filename: Name of the baseline file in the baselines directory

    Returns:
        Dictionary containing the baseline data
    """
    filepath = BASELINES_DIR / filename
    # Be explicit about the encoding so the result does not depend on the
    # platform's default locale.
    with open(filepath, "r", encoding="utf-8") as f:
        return json.load(f)


def compare_timing_results(actual: list, expected: dict, tolerance: float = 0.10):
    """Compare actual tuning results against expected baseline.

    Args:
        actual: List of result dictionaries from tune_kernel()
        expected: Baseline dictionary with a 'cache' key and, optionally, a
            'tune_params_keys' list naming the tunable parameters. When that
            list is absent the single parameter 'block_size_x' is assumed.
        tolerance: Allowed relative difference (default 10%)

    Returns:
        List of (config_key, expected_time, actual_time, diff_pct) for failures
    """
    # NOTE(review): multi-parameter keys are built as the comma-joined values
    # in 'tune_params_keys' order - confirm this matches the cache file format.
    param_keys = expected.get("tune_params_keys") or ["block_size_x"]
    cache = expected["cache"]
    failures = []

    for result in actual:
        # A result that lacks one of the parameters cannot be matched.
        if any(key not in result for key in param_keys):
            continue
        config_key = ",".join(str(result[key]) for key in param_keys)

        entry = cache.get(config_key)
        if entry is None:
            continue

        expected_time = entry["time"]
        actual_time = result["time"]

        # Skip error results (non-numeric times indicate compilation/runtime
        # failures). NumPy scalars count as numeric: np.float32 and the NumPy
        # integer types are not subclasses of Python's float/int.
        if not isinstance(actual_time, (int, float, np.integer, np.floating)):
            continue

        # A relative difference is only defined for a positive baseline.
        if expected_time > 0:
            diff_pct = abs(actual_time - expected_time) / expected_time
            if diff_pct > tolerance:
                failures.append((config_key, expected_time, actual_time, diff_pct))

    return failures
def test_baseline_config_coverage(self, vector_add_env):
    """Check that the baseline has an entry for every block size under test.

    The block sizes are read from the ``block_size_x`` entry of the last
    element of the ``vector_add_env`` fixture; each one, converted to a
    string, must be a key of the baseline's ``cache`` mapping.
    """
    cached_configs = load_baseline("vector_add_NVIDIA_RTX_A4000.json")["cache"]
    tune_params = vector_add_env[-1]

    for block_size in tune_params["block_size_x"]:
        assert str(block_size) in cached_configs, \
            f"Missing baseline data for block_size_x={block_size}"


def test_timing_values_reasonable(self):
    """Sanity-check that every baseline time lies in a plausible range.

    Each cache entry's ``time`` must be strictly between 0.001 and 1000
    (the assertion messages report the value in milliseconds).
    """
    baseline = load_baseline("vector_add_NVIDIA_RTX_A4000.json")
    cache = baseline["cache"]

    for config_key in cache:
        time_ms = cache[config_key]["time"]
        # Upper bound: 1000 ms, i.e. one second.
        assert time_ms < 1000, \
            f"Unreasonably high time {time_ms}ms for config {config_key}"
        # Lower bound: reject non-positive or essentially-zero measurements.
        assert time_ms > 0.001, \
            f"Suspiciously low time {time_ms}ms for config {config_key}"