diff --git a/README.md b/README.md
index 7343e73b..b7021315 100644
--- a/README.md
+++ b/README.md
@@ -117,7 +117,7 @@ uv run python scripts/generate_and_eval_single_sample.py dataset_src=huggingface
 **What you might need to modify**
 * **`gpu_arch`** - Depend on your GPU, you might need to adjust the `gpu_arch` argument to reflect your hardware.
 * **`precision`** - You can specify the precision of tensor by `precision=fp32`. Currently all of our reported results are `fp32` but we added support for `fp16` & `bf16`.
-* **`backend`** - We are also supporting other GPU programming languages beyond `cuda`. Simply specify `backend=triton`. For now we support DSLs: `cuda`, `triton`, `cute`, `tilelang`, `thunderkittens`.
+* **`backend`** - We also support GPU programming languages beyond `cuda`; simply specify, e.g., `backend=triton` or `backend=hip`. For now we support: `cuda`, `hip`, `triton`, `cute`, `tilelang`, `thunderkittens`.
 
 Note on setting up ThunderKittens (TK) locally: to use `backend=thunderkittens`, you need to git clone the ThunderKittens repo and set the following environment variable to point to your local ThunderKittens directory, `export THUNDERKITTENS_ROOT=`, and all ThunderKitten programs as shown in the [example](src/kernelbench/prompts/model_new_ex_add_thunderkittens.py), should contain `tk_root = os.environ.get("THUNDERKITTENS_ROOT", "/root/ThunderKittens")`, which enable the kernel to include the right TK primitives. In addition, we only support BF16 for TK right now.
diff --git a/scripts/generate_and_eval_single_sample.py b/scripts/generate_and_eval_single_sample.py
index fce1b16f..26854cd2 100644
--- a/scripts/generate_and_eval_single_sample.py
+++ b/scripts/generate_and_eval_single_sample.py
@@ -174,7 +174,7 @@ def main(config: EvalConfig):
         include_hardware = include_hardware.lower() in ["true", "1", "yes"]
     config.include_hardware_info = include_hardware
 
-    supported_backends = {"cuda", "triton", "tilelang", "cute", "thunderkittens"}
+    supported_backends = {"cuda", "triton", "tilelang", "cute", "thunderkittens", "hip"}
     backend = config.backend.lower()
     if backend not in supported_backends:
         raise ValueError(
@@ -278,4 +278,4 @@
 
 
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
diff --git a/scripts/generate_and_eval_single_sample_modal.py b/scripts/generate_and_eval_single_sample_modal.py
index 783076d9..a21a9959 100644
--- a/scripts/generate_and_eval_single_sample_modal.py
+++ b/scripts/generate_and_eval_single_sample_modal.py
@@ -207,7 +207,7 @@ def main(config: EvalConfig):
         include_hardware = include_hardware.lower() in ["true", "1", "yes"]
     config.include_hardware_info = include_hardware
 
-    supported_backends = {"cuda", "triton", "tilelang", "cute", "thunderkittens"}
+    supported_backends = {"cuda", "triton", "tilelang", "cute", "thunderkittens", "hip"}
     backend = config.backend.lower()
     if backend not in supported_backends:
         raise ValueError(
@@ -298,4 +298,4 @@
         f.write(str(kernel_exec_result))
 
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
diff --git a/scripts/generate_samples.py b/scripts/generate_samples.py
index 2c01ee8d..27e80fa1 100644
--- a/scripts/generate_samples.py
+++ b/scripts/generate_samples.py
@@ -234,7 +234,7 @@ def main(config: GenerationConfig):
         include_hardware = include_hardware.lower() in ["true", "1", "yes"]
     config.include_hardware_info = include_hardware
 
-    supported_backends = {"cuda", "triton", "cute", "tilelang", "thunderkittens"}
+    supported_backends = {"cuda", "triton", "cute", "tilelang", "thunderkittens", "hip"}
{"cuda", "triton", "cute", "tilelang", "thunderkittens", "hip"} backend = config.backend.lower() if backend not in supported_backends: raise ValueError( diff --git a/src/kernelbench/eval.py b/src/kernelbench/eval.py index dd79b2c0..886d5732 100644 --- a/src/kernelbench/eval.py +++ b/src/kernelbench/eval.py @@ -374,7 +374,7 @@ def _process_input_tensor(input, device, backend="cuda", precision=torch.float32 Args: input: Input tensor or non-tensor value device: Target CUDA device - backend: Backend type (e.g., 'cuda', 'triton', 'cute') + backend: Backend type (e.g., 'cuda', `hip`, 'triton', 'cute') precision: torch.dtype Returns: Processed tensor on correct device with correct dtype, or original value if not a tensor @@ -404,7 +404,7 @@ def eval_kernel_against_ref( device: Union[torch.device, int] = ( torch.cuda.current_device() if torch.cuda.is_available() else None ), # have to run on GPU - backend: str = "cuda", # can be 'cuda', 'triton', 'tilelang', or 'cute' + backend: str = "cuda", # can be 'cuda', 'hip', 'triton', 'tilelang', or 'cute' precision: torch.dtype = torch.float32, # Guard against potential reward hacking [optional but ongoing enhancement] @@ -420,7 +420,7 @@ def eval_kernel_against_ref( num_correct_trials: number of trials to initialize different random inputs; correctness pass only if all trials pass num_perf_trials: run the evalutation many times to take the average device: GPU (cuda) device to run the evalutation on - backend: str, one of 'cuda', 'triton', 'tilelang', or 'cute' + backend: str, one of 'cuda', 'hip', 'triton', 'tilelang', or 'cute' precision: torch.dtype for computation (note: tilelang only supports fp16) timing_method: str, method to time kernel, see timing.py for more details @@ -503,7 +503,7 @@ def eval_kernel_against_ref( custom_model_src, entry_point="ModelNew" ) else: - # Default CUDA backend + # Default CUDA/HIP backend ModelNew = load_custom_model(custom_model_src, context, build_dir) torch.cuda.synchronize(device=device) # not sure if this is too much except Exception as e: diff --git a/src/kernelbench/prompts/hardware/gpu_specs.py b/src/kernelbench/prompts/hardware/gpu_specs.py index 800f20ef..53ca7a43 100644 --- a/src/kernelbench/prompts/hardware/gpu_specs.py +++ b/src/kernelbench/prompts/hardware/gpu_specs.py @@ -118,7 +118,91 @@ "Maximum number of thread blocks per SM": "32", "Shared memory capacity per SM": "164 KB", "Maximum shared memory per thread block": "163 KB", - } + }, + "MI300X": { + "GPU Architecture": "gfx942", + "GPU Memory": "192GB", + "Memory Bandwidth": "5.3 TB/s", + "FP64 TFLOPS": "81.7", + "FP64 Matrix Core TFLOPS": "163.4", + "FP32 TFLOPS": "163.4", + "TF32 Matrix Core TFLOPS": "653.7 (1307.4 with sparsity)", + "BFLOAT16 Matrix Core TFLOPS": "1307.4 (2614.9 with sparsity)", + "FP16 Matrix Core TFLOPS": "1307.4 (2614.9 with sparsity)", + "FP8 Matrix Core TFLOPS": "2614.9 (5229.8 with sparsity)", + "INT8 Matrix Core TOPS": "2614.9 (5229.8 with sparsity)", + "Number of CU": "304", + "SIMDs per CU": "4", + "Wavefront Size": "64", + "Workgroup Max Size": "1024", + "Max Waves Per CU": "32", + "Max Threads per CU": "2048", + "Maximum number of registers per thread": "256", + "Shared memory capacity per CU": "64 KB", + }, + "MI325X": { + "GPU Architecture": "gfx942", + "GPU Memory": "256GB", + "Memory Bandwidth": "6TB/s", + "FP64 TFLOPS": "81.7", + "FP64 Matrix Core TFLOPS": "163.4", + "FP32 TFLOPS": "163.4", + "TF32 Matrix Core TFLOPS": "653.7 (1307.4 with sparsity)", + "BFLOAT16 Matrix Core TFLOPS": "1307.4 (2614.9 with sparsity)", 
+ "FP16 Matrix Core TFLOPS": "1307.4 (2614.9 with sparsity)", + "FP8 Matrix Core TFLOPS": "2614.9 (5229.8 with sparsity)", + "INT8 Matrix Core TOPS": "2614.9 (5229.8 with sparsity)", + "Number of CU": "304", + "SIMDs per CU": "4", + "Wavefront Size": "64", + "Workgroup Max Size": "1024", + "Max Waves Per CU": "32", + "Max Threads per CU": "2048", + "Maximum number of registers per thread": "256", + "Shared memory capacity per CU": "64 KB", + }, + "MI350X": { + "GPU Architecture": "gfx950", + "GPU Memory": "288GB", + "Memory Bandwidth": "8TB/s", + "FP64 TFLOPS": "72.1", + "FP64 Matrix Core TFLOPS": "72.1", + "FP32 TFLOPS": "144.2", + "BFLOAT16 Matrix Core TFLOPS": "2300 (4600 with sparsity)", + "FP16 Matrix Core TFLOPS": "2300 (4600 with sparsity)", + "FP8 Matrix Core TFLOPS": "4600", + "MXFP6, MXFP4 Matrix Core TFLOPS": "9200", + "INT8 Matrix Core TOPS": "4600 (9200 with sparsity)", + "Number of CU": "256", + "SIMDs per CU": "4", + "Wavefront Size": "64", + "Workgroup Max Size": "1024", + "Max Waves Per CU": "32", + "Max Threads per CU": "2048", + "Maximum number of registers per thread": "256", + "Shared memory capacity per CU": "160 KB", + }, + "MI355X": { + "GPU Architecture": "gfx950", + "GPU Memory": "288GB", + "Memory Bandwidth": "8TB/s", + "FP64 TFLOPS": "78.6", + "FP64 Matrix Core TFLOPS": "78.6", + "FP32 TFLOPS": "157.3", + "BFLOAT16 Matrix Core TFLOPS": "2500 (5000 with sparsity)", + "FP16 Matrix Core TFLOPS": "2500 (5000 with sparsity)", + "FP8 Matrix Core TFLOPS": "5000", + "MXFP6, MXFP4 Matrix Core TFLOPS": "10000", + "INT8 Matrix Core TOPS": "5000 (10000 with sparsity)", + "Number of CU": "256", + "SIMDs per CU": "4", + "Wavefront Size": "64", + "Workgroup Max Size": "1024", + "Max Waves Per CU": "32", + "Max Threads per CU": "2048", + "Maximum number of registers per thread": "256", + "Shared memory capacity per CU": "160 KB", + }, } # Basic GPU concept definitions diff --git a/src/kernelbench/prompts/model_new_ex_add_hip.py b/src/kernelbench/prompts/model_new_ex_add_hip.py new file mode 100644 index 00000000..fa66cf03 --- /dev/null +++ b/src/kernelbench/prompts/model_new_ex_add_hip.py @@ -0,0 +1,45 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.cpp_extension import load_inline + +import os +os.environ["CXX"] = "hipcc" + +elementwise_add_cpp_source = """ +#include + +__global__ void elementwise_add_kernel(const float* a, const float* b, float* out, int size) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < size) { + out[idx] = a[idx] + b[idx]; + } +} + +torch::Tensor elementwise_add_hip(torch::Tensor a, torch::Tensor b) { + auto size = a.numel(); + auto out = torch::zeros_like(a); + + const int block_size = 256; + const int num_blocks = (size + block_size - 1) / block_size; + + elementwise_add_kernel<<>>(a.data_ptr(), b.data_ptr(), out.data_ptr(), size); + + return out; +} +""" + +elementwise_add = load_inline( + name="elementwise_add", + cpp_sources=elementwise_add_cpp_source, + functions=["elementwise_add_hip"], + verbose=True, +) + +class ModelNew(nn.Module): + def __init__(self) -> None: + super().__init__() + self.elementwise_add = elementwise_add + + def forward(self, a, b): + return self.elementwise_add.elementwise_add_hip(a, b) \ No newline at end of file