diff --git a/README.md b/README.md
index 7343e73b..a53d5672 100644
--- a/README.md
+++ b/README.md
@@ -85,6 +85,9 @@ We have transitioned to using `pyproject.toml` and `uv` for dependency management
 # Install base dependencies (works without a local GPU)
 uv sync
 
+# Install with AMD ROCm backend (ROCm>=7.1 is required)
+uv add torch --index pytorch=https://download.pytorch.org/whl/rocm7.1
+
 # Install with GPU dependencies (for local GPU evaluation)
 uv sync --extra gpu
 
@@ -115,10 +118,9 @@ uv run python scripts/generate_and_eval_single_sample.py dataset_src=huggingface
 ```
 
 **What you might need to modify**
-* **`gpu_arch`** - Depend on your GPU, you might need to adjust the `gpu_arch` argument to reflect your hardware.
+* **`gpu_arch`** - Depending on your GPU, you may need to adjust the `gpu_arch` argument to reflect your hardware. The `gpu_arch` values currently supported for the `hip` backend are `gfx942` and `gfx950`.
 * **`precision`** - You can specify the precision of tensor by `precision=fp32`. Currently all of our reported results are `fp32` but we added support for `fp16` & `bf16`.
-* **`backend`** - We are also supporting other GPU programming languages beyond `cuda`. Simply specify `backend=triton`. For now we support DSLs: `cuda`, `triton`, `cute`, `tilelang`, `thunderkittens`.
-
+* **`backend`** - We also support other GPU programming languages beyond `cuda`. For example, simply specify `backend=triton` or `backend=hip`. For now we support the DSLs: `cuda`, `hip`, `triton`, `cute`, `tilelang`, `thunderkittens`.
 Note on setting up ThunderKittens (TK) locally: to use `backend=thunderkittens`, you need to git clone the ThunderKittens repo and set the following environment variable to point to your local ThunderKittens directory, `export THUNDERKITTENS_ROOT=`, and all ThunderKittens programs, as shown in the [example](src/kernelbench/prompts/model_new_ex_add_thunderkittens.py), should contain `tk_root = os.environ.get("THUNDERKITTENS_ROOT", "/root/ThunderKittens")`, which enables the kernel to include the right TK primitives. In addition, we only support BF16 for TK right now.
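As a quick sanity check that the ROCm wheel from the install step above is actually in use, the following minimal Python sketch can be run. It is not part of this PR; it only assumes a ROCm build of PyTorch and a visible AMD GPU:

```python
# Sanity-check sketch for the ROCm install above (not part of this PR).
import torch

print(torch.__version__)            # a ROCm wheel reports something like "2.9.0+rocm7.1"
print(torch.version.hip)            # HIP runtime version on ROCm builds; None on CUDA builds
print(torch.cuda.is_available())    # ROCm exposes AMD GPUs through the "cuda" device API
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))   # e.g. an AMD Instinct accelerator
```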
diff --git a/pyproject.toml b/pyproject.toml
index bed37150..f3f98ae2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -10,10 +10,10 @@ version = "0.2.0.dev0"
 requires-python = "==3.10.*"
 
 dependencies = [
     # Frameworks
-    "torch==2.9.0",
+    "torch>=2.9.0",
     "transformers",
-    "datasets",
+    "datasets>=2.19.0",
     "modal",
 
     # helper
diff --git a/scripts/generate_and_eval_single_sample.py b/scripts/generate_and_eval_single_sample.py
index fce1b16f..95aea5d7 100644
--- a/scripts/generate_and_eval_single_sample.py
+++ b/scripts/generate_and_eval_single_sample.py
@@ -124,6 +124,8 @@ def main(config: EvalConfig):
     )
 
     if config.gpu_arch:
+        if not isinstance(config.gpu_arch, list):
+            config.gpu_arch = [config.gpu_arch]
         set_gpu_arch(config.gpu_arch)  # otherwise build for all architectures
 
     if config.log:
@@ -174,7 +176,7 @@ def main(config: EvalConfig):
         include_hardware = include_hardware.lower() in ["true", "1", "yes"]
         config.include_hardware_info = include_hardware
 
-    supported_backends = {"cuda", "triton", "tilelang", "cute", "thunderkittens"}
+    supported_backends = {"cuda", "hip", "triton", "tilelang", "cute", "thunderkittens"}
     backend = config.backend.lower()
     if backend not in supported_backends:
         raise ValueError(
diff --git a/src/kernelbench/prompts/hardware/gpu_specs.py b/src/kernelbench/prompts/hardware/gpu_specs.py
index 800f20ef..ca63488f 100644
--- a/src/kernelbench/prompts/hardware/gpu_specs.py
+++ b/src/kernelbench/prompts/hardware/gpu_specs.py
@@ -118,6 +118,90 @@
         "Maximum number of thread blocks per SM": "32",
         "Shared memory capacity per SM": "164 KB",
         "Maximum shared memory per thread block": "163 KB",
+    },
+    "MI300X": {
+        "GPU Architecture": "gfx942",
+        "GPU Memory": "192 GB",
+        "Memory Bandwidth": "5.3 TB/s",
+        "FP64 TFLOPS": "81.7",
+        "FP64 Matrix Core TFLOPS": "163.4",
+        "FP32 TFLOPS": "163.4",
+        "TF32 Matrix Core TFLOPS": "653.7 (1307.4 with sparsity)",
+        "BFLOAT16 Matrix Core TFLOPS": "1307.4 (2614.9 with sparsity)",
+        "FP16 Matrix Core TFLOPS": "1307.4 (2614.9 with sparsity)",
+        "FP8 Matrix Core TFLOPS": "2614.9 (5229.8 with sparsity)",
+        "INT8 Matrix Core TOPS": "2614.9 (5229.8 with sparsity)",
+        "Number of CU": "304",
+        "SIMDs per CU": "4",
+        "Wavefront Size": "64",
+        "Workgroup Max Size": "1024",
+        "Max Waves Per CU": "32",
+        "Max Threads per CU": "2048",
+        "Maximum number of registers per thread": "256",
+        "Shared memory capacity per CU": "64 KB",
+    },
+    "MI325X": {
+        "GPU Architecture": "gfx942",
+        "GPU Memory": "256 GB",
+        "Memory Bandwidth": "6 TB/s",
+        "FP64 TFLOPS": "81.7",
+        "FP64 Matrix Core TFLOPS": "163.4",
+        "FP32 TFLOPS": "163.4",
+        "TF32 Matrix Core TFLOPS": "653.7 (1307.4 with sparsity)",
+        "BFLOAT16 Matrix Core TFLOPS": "1307.4 (2614.9 with sparsity)",
+        "FP16 Matrix Core TFLOPS": "1307.4 (2614.9 with sparsity)",
+        "FP8 Matrix Core TFLOPS": "2614.9 (5229.8 with sparsity)",
+        "INT8 Matrix Core TOPS": "2614.9 (5229.8 with sparsity)",
+        "Number of CU": "304",
+        "SIMDs per CU": "4",
+        "Wavefront Size": "64",
+        "Workgroup Max Size": "1024",
+        "Max Waves Per CU": "32",
+        "Max Threads per CU": "2048",
+        "Maximum number of registers per thread": "256",
+        "Shared memory capacity per CU": "64 KB",
+    },
+    "MI350X": {
+        "GPU Architecture": "gfx950",
+        "GPU Memory": "288 GB",
+        "Memory Bandwidth": "8 TB/s",
+        "FP64 TFLOPS": "72.1",
+        "FP64 Matrix Core TFLOPS": "72.1",
+        "FP32 TFLOPS": "144.2",
+        "BFLOAT16 Matrix Core TFLOPS": "2300 (4600 with sparsity)",
+        "FP16 Matrix Core TFLOPS": "2300 (4600 with sparsity)",
+        "FP8 Matrix Core TFLOPS": "4600",
+        "MXFP6, MXFP4 Matrix Core TFLOPS": "9200",
+        "INT8 Matrix Core TOPS": "4600 (9200 with sparsity)",
+        "Number of CU": "256",
+        "SIMDs per CU": "4",
+        "Wavefront Size": "64",
+        "Workgroup Max Size": "1024",
+        "Max Waves Per CU": "32",
+        "Max Threads per CU": "2048",
+        "Maximum number of registers per thread": "256",
+        "Shared memory capacity per CU": "160 KB",
+    },
+    "MI355X": {
+        "GPU Architecture": "gfx950",
+        "GPU Memory": "288 GB",
+        "Memory Bandwidth": "8 TB/s",
+        "FP64 TFLOPS": "78.6",
+        "FP64 Matrix Core TFLOPS": "78.6",
+        "FP32 TFLOPS": "157.3",
+        "BFLOAT16 Matrix Core TFLOPS": "2500 (5000 with sparsity)",
+        "FP16 Matrix Core TFLOPS": "2500 (5000 with sparsity)",
+        "FP8 Matrix Core TFLOPS": "5000",
+        "MXFP6, MXFP4 Matrix Core TFLOPS": "10000",
+        "INT8 Matrix Core TOPS": "5000 (10000 with sparsity)",
+        "Number of CU": "256",
+        "SIMDs per CU": "4",
+        "Wavefront Size": "64",
+        "Workgroup Max Size": "1024",
+        "Max Waves Per CU": "32",
+        "Max Threads per CU": "2048",
+        "Maximum number of registers per thread": "256",
+        "Shared memory capacity per CU": "160 KB",
     }
 }
diff --git a/src/kernelbench/prompts/model_new_ex_add_hip.py b/src/kernelbench/prompts/model_new_ex_add_hip.py
new file mode 100644
index 00000000..2498bc18
--- /dev/null
+++ b/src/kernelbench/prompts/model_new_ex_add_hip.py
@@ -0,0 +1,45 @@
+import os
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.cpp_extension import load_inline
+
+os.environ["CXX"] = "hipcc"
+
+elementwise_add_cpp_source = """
+#include <torch/extension.h>
+
+__global__ void elementwise_add_kernel(const float* a, const float* b, float* out, int size) {
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx < size) {
+        out[idx] = a[idx] + b[idx];
+    }
+}
+
+torch::Tensor elementwise_add_hip(torch::Tensor a, torch::Tensor b) {
+    auto size = a.numel();
+    auto out = torch::zeros_like(a);
+
+    const int block_size = 256;
+    const int num_blocks = (size + block_size - 1) / block_size;
+
+    elementwise_add_kernel<<<num_blocks, block_size>>>(a.data_ptr<float>(), b.data_ptr<float>(), out.data_ptr<float>(), size);
+
+    return out;
+}
+"""
+
+elementwise_add = load_inline(
+    name="elementwise_add",
+    cpp_sources=elementwise_add_cpp_source,
+    functions=["elementwise_add_hip"],
+    verbose=True,
+)
+
+class ModelNew(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.elementwise_add = elementwise_add
+
+    def forward(self, a, b):
+        return self.elementwise_add.elementwise_add_hip(a, b)
\ No newline at end of file
diff --git a/src/kernelbench/prompts/prompts.toml b/src/kernelbench/prompts/prompts.toml
index 2768aa11..61b6b15f 100644
--- a/src/kernelbench/prompts/prompts.toml
+++ b/src/kernelbench/prompts/prompts.toml
@@ -54,6 +54,11 @@ backend_display = "ThunderKittens kernels"
 one_shot_new_arch = "src/kernelbench/prompts/model_new_ex_add_thunderkittens.py"
 # No few_shot_examples - will use one-shot when few_shot option is selected
 
+[backends.hip]
+backend_display = "HIP kernels"
+one_shot_new_arch = "src/kernelbench/prompts/model_new_ex_add_hip.py"
+# No few_shot_examples - will use one-shot when few_shot option is selected
+
 # -------------------------------------------------------------------------
 # Precision: Precision-specific configuration
 # -------------------------------------------------------------------------
diff --git a/src/kernelbench/utils.py b/src/kernelbench/utils.py
index cf8b0ad8..bbd6a468 100644
--- a/src/kernelbench/utils.py
+++ b/src/kernelbench/utils.py
@@ -42,7 +42,7 @@ def set_gpu_arch(arch_list: list[str]):
     """
     Set env variable for torch cuda arch list to build kernels for specified architectures
     """
-    valid_archs = ["Maxwell", "Pascal", "Volta", "Turing", "Ampere", "Hopper", "Ada"]
+    valid_archs = ["Maxwell", "Pascal", "Volta", "Turing", "Ampere", "Hopper", "Ada", "gfx942", "gfx950"]
     for arch in arch_list:
         if arch not in valid_archs:
            raise ValueError(f"Invalid architecture: {arch}. Must be one of {valid_archs}")
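Taken together, the script change and the extended `valid_archs` list mean a bare architecture string from the CLI now round-trips cleanly. A sketch of the assumed control flow (it mirrors the diff; not verbatim repo code):

```python
# Sketch of how the two changes compose (mirrors the diff; not verbatim repo code).
from kernelbench.utils import set_gpu_arch

gpu_arch = "gfx942"                    # e.g. passed as gpu_arch=gfx942 on the CLI
if not isinstance(gpu_arch, list):     # normalization added in generate_and_eval_single_sample.py
    gpu_arch = [gpu_arch]
set_gpu_arch(gpu_arch)                 # passes validation now that gfx942/gfx950 are accepted
```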