diff --git a/3rdparty/tvm b/3rdparty/tvm
index 41edb06ed..7fd5859f4 160000
--- a/3rdparty/tvm
+++ b/3rdparty/tvm
@@ -1 +1 @@
-Subproject commit 41edb06ed039944978c671afbd2dde5f22667c83
+Subproject commit 7fd5859f4d32a0c98db2e1186cb0ba1606bd6607
diff --git a/bitblas/__init__.py b/bitblas/__init__.py
index 32ff07132..940f1b4ac 100644
--- a/bitblas/__init__.py
+++ b/bitblas/__init__.py
@@ -147,7 +147,7 @@ def remove_tvm_path(path):
     ApplyDefaultSchedule,  # noqa: F401
     ApplyFastTuning,  # noqa: F401
 )
-from .utils import auto_detect_nvidia_target, apply_transform_on_input  # noqa: F401
+from .utils import auto_detect_target, auto_detect_nvidia_target, apply_transform_on_input  # noqa: F401
 from .ops.general_matmul import MatmulConfig, Matmul  # noqa: F401
 from .ops.general_matmul_splitk import MatmulConfigWithSplitK, MatmulWithSplitK  # noqa: F401
 from .ops.general_flashatten import FlashAttenConfig, FlashAtten  # noqa: F401
diff --git a/bitblas/base/arch/__init__.py b/bitblas/base/arch/__init__.py
index dd931f617..8c0099ca0 100644
--- a/bitblas/base/arch/__init__.py
+++ b/bitblas/base/arch/__init__.py
@@ -6,6 +6,7 @@ from .cdna import CDNA
 from typing import Union
 from tvm.target import Target
+from bitblas.utils.target_detector import auto_detect_target
 
 
 def get_arch(target: Union[str, Target] = "cuda") -> TileDevice:
@@ -22,12 +23,6 @@ def get_arch(target: Union[str, Target] = "cuda") -> TileDevice:
     raise ValueError(f"Unsupported target: {target.kind.name}")
 
 
-def auto_infer_current_arch() -> TileDevice:
-    # TODO(lei): This is a temporary solution to infer the current architecture
-    # Can be replaced by a more sophisticated method in the future
-    return get_arch("cuda")
-
-
 from .cpu import is_cpu_arch  # noqa: F401
 from .cuda import (
     is_cuda_arch,  # noqa: F401
@@ -38,4 +33,9 @@ def auto_infer_current_arch() -> TileDevice:
     is_tensorcore_supported_precision,  # noqa: F401
     has_mma_support,  # noqa: F401
 )
-from .cdna import is_cdna_arch  # noqa: F401
+from .cdna import is_cdna_arch, is_matrixcore_supported_precision  # noqa: F401
+
+
+def auto_infer_current_arch() -> TileDevice:
+    target = auto_detect_target()
+    return get_arch(target)
diff --git a/bitblas/base/arch/cdna.py b/bitblas/base/arch/cdna.py
index cb49041db..aa7c3ec72 100644
--- a/bitblas/base/arch/cdna.py
+++ b/bitblas/base/arch/cdna.py
@@ -11,6 +11,21 @@ def is_cdna_arch(arch: TileDevice) -> bool:
     return isinstance(arch, CDNA)
 
 
+# AMD Matrix Core Configurations
+cdna_matrixcore_supported = [
+    ("float16", "float32"),
+    ("int8", "int32"),
+]
+
+
+def is_matrixcore_supported_precision(in_dtype: str, accum_dtype: str, arch: TileDevice) -> bool:
+
+    if is_cdna_arch(arch):
+        return (in_dtype, accum_dtype) in cdna_matrixcore_supported
+    else:
+        raise ValueError(f"Unsupported architecture: {arch}")
+
+
 class CDNA(TileDevice):
 
     def __init__(self, target: Union[Target, str]):
diff --git a/bitblas/base/arch/cuda.py b/bitblas/base/arch/cuda.py
index 5e8730d67..05704afde 100644
--- a/bitblas/base/arch/cuda.py
+++ b/bitblas/base/arch/cuda.py
@@ -18,40 +18,35 @@ def is_cuda_arch(arch: TileDevice) -> bool:
 
 def is_volta_arch(arch: TileDevice) -> bool:
     conditions = [True]
-    conditions.append(is_cuda_arch(arch))
-    conditions.append(arch.sm_version >= 70)
-    conditions.append(arch.sm_version < 80)
+    conditions.append(is_cuda_arch(arch) and arch.sm_version >= 70 and arch.sm_version < 80)
     return all(conditions)
 
 
 def is_ampere_arch(arch: TileDevice) -> bool:
     conditions = [True]
-    conditions.append(is_cuda_arch(arch))
-    conditions.append(arch.sm_version >= 80 and arch.sm_version < 90)
+    conditions.append(is_cuda_arch(arch) and arch.sm_version >= 80 and arch.sm_version < 90)
     return all(conditions)
 
 
 def is_ada_arch(arch: TileDevice) -> bool:
     conditions = [True]
-    conditions.append(is_cuda_arch(arch))
-    conditions.append(arch.sm_version == 89)
+    conditions.append(is_cuda_arch(arch) and arch.sm_version == 89)
     return all(conditions)
 
 
 def is_hopper_arch(arch: TileDevice) -> bool:
     conditions = [True]
-    conditions.append(is_cuda_arch(arch))
-    conditions.append(arch.sm_version == 90)
+    conditions.append(is_cuda_arch(arch) and arch.sm_version == 90)
     return all(conditions)
 
 
 def has_mma_support(arch: TileDevice) -> bool:
     conditions = [True]
-    conditions.append(is_cuda_arch(arch))
-    conditions.append(arch.sm_version >= 80)
+    conditions.append(is_cuda_arch(arch) and arch.sm_version >= 80)
     return all(conditions)
 
 
+# NVIDIA Tensor Core Configurations
 volta_tensorcore_supported = [
     ("float16", "float32"),
     ("float16", "float16"),
diff --git a/bitblas/benchmark/operator/__init__.py b/bitblas/benchmark/operator/__init__.py
index 7c21d9d0c..b8ebc2a81 100644
--- a/bitblas/benchmark/operator/__init__.py
+++ b/bitblas/benchmark/operator/__init__.py
@@ -7,7 +7,7 @@
 from typing import Dict, List, Tuple, Optional
 from bitblas.ops import Operator, OperatorConfig
 from bitblas.utils import get_default_cache_path
-from bitblas import auto_detect_nvidia_target
+from bitblas import auto_detect_target
 from bitblas import tvm as tvm
 from bitblas.cache import OperatorCache
 import logging
@@ -21,7 +21,7 @@ class BitblasOperatorBenchmarkBase(ABC):
     benchmark_sets: Dict[str, List[Tuple[Operator, OperatorConfig, Optional[int]]]] = {}
 
     # Currently we only support NVIDIA target for benchmarking
-    benchmark_target: str = auto_detect_nvidia_target()
+    benchmark_target: str = auto_detect_target()
 
     # Benchmark results: a list of tuples, each containing latency and tuning time
     benchmark_results: Dict[str, List[Tuple[Optional[float], Optional[float]]]] = {}
diff --git a/bitblas/builder/wrapper/tl.py b/bitblas/builder/wrapper/tl.py
index 85a75601a..a72ec38a7 100644
--- a/bitblas/builder/wrapper/tl.py
+++ b/bitblas/builder/wrapper/tl.py
@@ -87,6 +87,9 @@ def get_cuda_init_func(self):
         init_funcs = PREDEF_INIT_FUNC.format(call_str)
         return init_funcs
 
+    def get_stream_argument(self) -> Dict:
+        return {"name": "stream=cudaStreamDefault", "type": "cudaStream_t"}
+
     def update_lib_code(self, code: str):
         # Update the library code with the given code string
         self.lib_code = code
@@ -115,7 +118,7 @@ def update_lib_code(self, code: str):
         for dyn_sym in dynamic_symbolic_set:
             function_args.append({"name": dyn_sym, "type": "int"})
 
-        function_args.append({"name": "stream=cudaStreamDefault", "type": "cudaStream_t"},)
+        function_args.append(self.get_stream_argument())
 
         # Format the function arguments for declaration
         def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args])
@@ -223,7 +226,7 @@ def create_dispatch_func(self, code, function_informations):
         for dyn_sym in dynamic_symbolic_set:
             function_args.append({"name": dyn_sym, "type": "int"})
 
-        function_args.append({"name": "stream=cudaStreamDefault", "type": "cudaStream_t"},)
+        function_args.append(self.get_stream_argument())
 
         # Format the argument definitions for function declaration
         def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args])
@@ -392,8 +395,8 @@ def get_hip_init_func(self):
         init_funcs = PREDEF_INIT_FUNC.format(call_str)
         return init_funcs
 
-    def get_stream_type(self, function_args):
-        function_args.append({"name": "stream=hipStreamDefault", "type": "hipStream_t"},)
+    def get_stream_argument(self) -> Dict:
+        return {"name": "stream=hipStreamDefault", "type": "hipStream_t"}
 
 
 class TLWrapper(BaseWrapper):
diff --git a/bitblas/cache/operator.py b/bitblas/cache/operator.py
index 1a1418758..9965e7eab 100644
--- a/bitblas/cache/operator.py
+++ b/bitblas/cache/operator.py
@@ -1,7 +1,7 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
 import bitblas
-from bitblas.utils import get_default_cache_path
+from bitblas.utils import get_default_cache_path, auto_detect_target
 from bitblas.ops.operator import OperatorConfig, Operator
 from dataclasses import asdict
 import os
@@ -186,7 +186,7 @@ def load_global_ops_cache(database_path=None, target=None):
     if database_path is None:
         database_path = get_database_path()
     if target is None:
-        target = bitblas.auto_detect_nvidia_target()
+        target = auto_detect_target()
     logger.info(f"Loading operators from database {database_path} for target {target}")
     global_operator_cache.load_from_database(database_path, target)
     return global_operator_cache
diff --git a/bitblas/module/__init__.py b/bitblas/module/__init__.py
index a85be8d51..0dc58d19e 100644
--- a/bitblas/module/__init__.py
+++ b/bitblas/module/__init__.py
@@ -16,7 +16,7 @@
 from bitblas.cache import global_operator_cache, get_database_path
 from bitblas import Matmul, MatmulConfig
 from bitblas.quantization.utils import general_compress
-from bitblas import auto_detect_nvidia_target
+from bitblas import auto_detect_target
 
 BITBLAS_DATABASE_PATH = get_database_path()
@@ -240,7 +240,7 @@ def _configure_bitblas_matmul(
         self.source_format = self.bitblas_matmul.source_format
 
     def _get_or_create_bitblas_operator(self, config, enable_tuning):
-        BITBLAS_TARGET = auto_detect_nvidia_target()
+        BITBLAS_TARGET = auto_detect_target()
         if global_operator_cache.size() == 0:
             global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH, BITBLAS_TARGET)
diff --git a/bitblas/ops/general_flashatten/__init__.py b/bitblas/ops/general_flashatten/__init__.py
index 6fc1c9ad1..59c8f42c1 100644
--- a/bitblas/ops/general_flashatten/__init__.py
+++ b/bitblas/ops/general_flashatten/__init__.py
@@ -6,7 +6,7 @@
 from bitblas.base.base_scheduler import BaseScheduler
 from ..operator import OperatorConfig, Operator, BaseKernelNameGenerator
 from ...base.arch.cuda import CUDA
-from ...utils import auto_detect_nvidia_target
+from ...utils import auto_detect_target
 from dataclasses import dataclass
 from typing import Union, Tuple, Literal, Optional, Any
 import logging
@@ -93,7 +93,7 @@ def __init__(
         backend: str = "tl",
     ):
         if target is None:
-            target = auto_detect_nvidia_target()
+            target = auto_detect_target()
             logger.info(f"Auto detected target: {target}")
 
         assert (config.Q_dtype
diff --git a/bitblas/ops/general_matmul/__init__.py b/bitblas/ops/general_matmul/__init__.py
index 723391777..68f749157 100644
--- a/bitblas/ops/general_matmul/__init__.py
+++ b/bitblas/ops/general_matmul/__init__.py
@@ -16,7 +16,7 @@
 from .tilelang.dequantize import select_scheduler as weight_dequantize_scheduler
 from ...base.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2
 from bitblas.utils import retrieve_func_from_module
-from bitblas.utils.target_detector import auto_detect_nvidia_target
+from bitblas.utils.target_detector import auto_detect_target
 from dataclasses import dataclass
 from ..ladder_permutate import LadderPermutate, LadderPermutateConfig
 from ..quant_compress import QuantCompress, QuantCompressConfig
@@ -356,7 +356,7 @@ def __init__(
         # if from database, we should disable default schedule
         # to save compilation time
         if target is None:
-            target = auto_detect_nvidia_target()
+            target = auto_detect_target()
             logger.info(f"Auto detected target: {target}")
 
         assert (config.A_dtype
diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul.py b/bitblas/ops/general_matmul/tilelang/dense/matmul.py
index 8b197e569..40526868b 100644
--- a/bitblas/ops/general_matmul/tilelang/dense/matmul.py
+++ b/bitblas/ops/general_matmul/tilelang/dense/matmul.py
@@ -5,12 +5,8 @@
 from tvm.tir import PrimFunc
 from bitblas.base.operator_common import TransformKind
 from bitblas.base.base_scheduler import BaseScheduler
-from bitblas.base.arch import (
-    TileDevice,
-    is_ampere_arch,
-    is_volta_arch,
-    is_tensorcore_supported_precision,
-)
+from bitblas.base.arch import (TileDevice, is_ampere_arch, is_volta_arch, is_cdna_arch,
+                               is_tensorcore_supported_precision, is_matrixcore_supported_precision)
 from dataclasses import dataclass
 from bitblas.tl.base_hint import BaseTLHint
@@ -128,11 +124,46 @@ def dispatch_volta_scheduler(self, arch: TileDevice) -> BaseScheduler:
         else:
             return self.matmul_simt_scheduler
 
+    def dispatch_cdna_scheduler(self, arch: TileDevice) -> BaseScheduler:
+        M = self.maybe_dynamic(self.M, "m")
+        N, K = self.N, self.K
+        assert isinstance(N, int) and isinstance(K, int), "Do not support dynamic N and K Currently"
+
+        is_dynamic = self.is_dynamic
+        in_dtype, accum_dtype = (
+            self.in_dtype,
+            self.accum_dtype,
+        )
+        if self.weight_transform_kind != TransformKind.NonTransform:
+            raise ValueError(
+                f"Weight propagation {self.weight_transform_kind} is not supported for CDNA")
+        if in_dtype not in ["int8", "float16", "float32", "float64"]:
+            raise ValueError(f"Unsupported input data type: {in_dtype}")
+
+        if is_dynamic:
+            # Dynamic Dispatcher
+            if is_matrixcore_supported_precision(in_dtype, accum_dtype, arch):
+                return self.matmul_block_scheduler
+            else:
+                return self.matmul_simt_scheduler
+        else:
+            minimal_tensorcore_threshold: List[int, int, int] = [8, 16, 16]
+            if minimal_tensorcore_threshold[0] > M or minimal_tensorcore_threshold[
+                    1] > N or minimal_tensorcore_threshold[2] > K:
+                return self.gemv_scheduler
+            elif is_matrixcore_supported_precision(in_dtype, accum_dtype, arch):
+                # Fine-grained scheduler (mma) is not implemented for CDNA
+                return self.matmul_block_scheduler
+            else:
+                return self.matmul_simt_scheduler
+
     def dispatch_scheduler(self, arch: TileDevice) -> BaseScheduler:
         if is_ampere_arch(arch):
             return self.dispatch_ampere_scheduler(arch)
         elif is_volta_arch(arch):
             return self.dispatch_volta_scheduler(arch)
+        elif is_cdna_arch(arch):
+            return self.dispatch_cdna_scheduler(arch)
         else:
             raise ValueError(f"Unsupported architecture: {arch}")
diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize.py
index 9716ac075..bfb7e752e 100644
--- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize.py
+++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize.py
@@ -9,7 +9,9 @@
     TileDevice,
     is_ampere_arch,
     is_volta_arch,
+    is_cdna_arch,
     is_tensorcore_supported_precision,
+    is_matrixcore_supported_precision,
 )
 from dataclasses import dataclass
 from bitblas.tl.base_hint import BaseTLHint
@@ -143,11 +145,57 @@ def dispatch_volta_scheduler(self, arch: TileDevice) -> BaseScheduler:
         else:
             return self.matmul_dequantize_simt_scheduler
 
+    def dispatch_cdna_scheduler(self, arch: TileDevice) -> BaseScheduler:
+        M = self.maybe_dynamic(self.M, "m")
+        N, K = self.N, self.K
+        assert isinstance(N, int) and isinstance(K, int), "Do not support dynamic N and K Currently"
+
+        is_dynamic = self.is_dynamic
+        in_dtype, accum_dtype = (
+            self.in_dtype,
+            self.accum_dtype,
+        )
+        weight_transform_kind = self.weight_transform_kind
+        if is_dynamic:
+            # Dynamic Dispatcher
+            if is_matrixcore_supported_precision(in_dtype, accum_dtype, arch):
+                if weight_transform_kind != TransformKind.NonTransform:
+                    raise NotImplementedError(
+                        "Weight propagation is not supported for MatrixCore with Dequantization")
+                return self.matmul_dequantize_fine_grained_scheduler
+            else:
+                if weight_transform_kind != TransformKind.NonTransform:
+                    raise ValueError(
+                        "Weight propagation is not supported for non-TensorCore architectures")
+                return self.matmul_dequantize_simt_scheduler
+        else:
+            minimal_tensorcore_threshold: List[int, int, int] = ([8, 16, 32] if accum_dtype
+                                                                 == "int32" else [8, 16, 16])
+            if (minimal_tensorcore_threshold[0] > M or minimal_tensorcore_threshold[1] > N or
+                    minimal_tensorcore_threshold[2] > K):
+                if in_dtype == "int4":
+                    raise ValueError("INT4 is not supported for non-TensorCore architectures")
+                if weight_transform_kind != TransformKind.NonTransform:
+                    raise ValueError(
+                        "Weight propagation is not supported for non-TensorCore architectures")
+                return self.gemv_dequantize_simt_scheduler
+            elif is_matrixcore_supported_precision(in_dtype, accum_dtype, arch):
+                if self.weight_transform_kind != TransformKind.NonTransform:
+                    raise NotImplementedError(
+                        "Weight propagation is not supported for MatrixCore with Dequantization")
+                else:
+                    return self.matmul_dequantize_fine_grained_scheduler
+            else:
+                return self.matmul_dequantize_simt_scheduler
+
     def dispatch_scheduler(self, arch: TileDevice) -> BaseScheduler:
         if is_ampere_arch(arch):
             return self.dispatch_ampere_scheduler(arch)
         elif is_volta_arch(arch):
             return self.dispatch_volta_scheduler(arch)
+        elif is_cdna_arch(arch):
+            return self.dispatch_cdna_scheduler(arch)
         else:
             raise ValueError(f"Unsupported architecture: {arch}")
diff --git a/bitblas/ops/operator.py b/bitblas/ops/operator.py
index ea99490c5..0bd8c3b3c 100644
--- a/bitblas/ops/operator.py
+++ b/bitblas/ops/operator.py
@@ -278,7 +278,7 @@ def _build_default_module(self, target: Target):
             self._update_optimized_mod(scheduled_ir_module)
         except Exception as apply_schedule_error:
             self.scheduled_ir_module = None
-            logger.warning(
+            logger.exception(
                 APPLY_SCHEDULE_FAILED_MESSAGE.format(self.__class__.__name__, target, "default",
                                                      apply_schedule_error))
diff --git a/bitblas/utils/__init__.py b/bitblas/utils/__init__.py
index e94b178aa..b35ba5dab 100644
--- a/bitblas/utils/__init__.py
+++ b/bitblas/utils/__init__.py
@@ -2,7 +2,7 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
 from .post_process import match_global_kernel, tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2  # noqa: F401
 from .tensor_adapter import tvm_tensor_to_torch, lazy_tvm_tensor_to_torch, lazy_torch_to_tvm_tensor  # noqa: F401
-from .target_detector import get_all_nvidia_targets, auto_detect_nvidia_target  # noqa: F401
+from .target_detector import auto_detect_target, get_all_nvidia_targets, auto_detect_nvidia_target  # noqa: F401
 from .rtmod_analysis import get_annotated_device_mod  # noqa: F401
 from .weight_propagate import apply_transform_on_input  # noqa: F401
diff --git a/bitblas/utils/target_detector.py b/bitblas/utils/target_detector.py
index 71d6dcc1f..e756ee84a 100644
--- a/bitblas/utils/target_detector.py
+++ b/bitblas/utils/target_detector.py
@@ -23,6 +23,7 @@
     "NVIDIA PG506-232": "NVIDIA A100",
 }
 
+
 def get_gpu_model_from_nvidia_smi(gpu_id: int = 0):
     """
     Executes the 'nvidia-smi' command to fetch the name of the first available NVIDIA GPU.
@@ -52,6 +53,7 @@ def get_gpu_model_from_nvidia_smi(gpu_id: int = 0):
 
     return gpus[gpu_id]
 
+
 def find_best_match(tags, query):
     """
     Finds the best match for a query within a list of tags using fuzzy string matching.
@@ -101,3 +103,48 @@ def auto_detect_nvidia_target(gpu_id: int = 0) -> str:
     target = find_best_match(nvidia_tags, gpu_model) if gpu_model else "cuda"
 
     return target
+
+
+def auto_detect_target(target: str = "auto", gpu_id: int = 0) -> str:
+    """Detect the computing target (CUDA or ROCm) based on the environment.
+
+    Args:
+        target (str): The target to detect. Use "auto" for automatic detection.
+            Can also specify "cuda" or "hip" explicitly.
+
+    Returns:
+        str: The detected target, either "cuda" or "hip".
+
+    Raises:
+        ValueError: If auto-detection is enabled and no valid target is found.
+    """
+
+    from tvm.contrib import nvcc, rocm
+
+    def is_cuda_available() -> bool:
+        """Check if CUDA is available."""
+        try:
+            nvcc.find_cuda_path()
+            return True
+        except RuntimeError:
+            return False
+
+    def is_rocm_available() -> bool:
+        """Check if ROCm is available."""
+        try:
+            rocm.find_rocm_path()
+            return True
+        except RuntimeError:
+            return False
+
+    if target == "auto":
+        if is_cuda_available():
+            return auto_detect_nvidia_target(gpu_id=gpu_id)
+        if is_rocm_available():
+            return "hip"
+        raise ValueError("Cannot detect the target: no CUDA or ROCm installation found.")
+
+    if target not in {"cuda", "hip"}:
+        raise ValueError(f"Invalid target: {target}. Must be 'cuda', 'hip', or 'auto'.")
+
+    return target
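
For reference, a minimal usage sketch of the detection path introduced above (not part of the patch; it assumes a BitBLAS build that already contains these changes, and the matmul shape and dtypes are arbitrary illustrative values):

# Illustrative sketch only -- exercises the new auto_detect_target() entry point.
import bitblas
from bitblas import Matmul, MatmulConfig

# "auto" (the default) probes for a CUDA toolkit first, then ROCm, and raises a
# ValueError if neither is found; it returns an NVIDIA target tag or "hip".
target = bitblas.auto_detect_target()

# Example config; on a ROCm machine the CDNA dispatcher added above is selected.
config = MatmulConfig(M=1, N=1024, K=1024, A_dtype="float16", W_dtype="int4")
matmul = Matmul(config, target=target)  # equivalent to leaving target=None after this change

Leaving target=None in Matmul, FlashAtten, the benchmark base class, and the operator cache now routes through the same detection, so existing CUDA-only callers keep their behavior while ROCm machines fall back to the "hip" target.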