diff --git a/3rdparty/tilelang b/3rdparty/tilelang
index 6aef1f896..b09e2b5cc 160000
--- a/3rdparty/tilelang
+++ b/3rdparty/tilelang
@@ -1 +1 @@
-Subproject commit 6aef1f8968bb3f8f806b74eb1334eb2a44a9ab3a
+Subproject commit b09e2b5cc6abfe94c35249cb99ad899ef394964e
diff --git a/3rdparty/tvm b/3rdparty/tvm
index b372d9ca2..d310bd5aa 160000
--- a/3rdparty/tvm
+++ b/3rdparty/tvm
@@ -1 +1 @@
-Subproject commit b372d9ca2159a1afd5439990f68bfa29578a8bac
+Subproject commit d310bd5aadce96145546fb7a87a6d325ea392b2b
diff --git a/bitblas/base/arch/cuda.py b/bitblas/base/arch/cuda.py
index 5e8730d67..25c83bff1 100644
--- a/bitblas/base/arch/cuda.py
+++ b/bitblas/base/arch/cuda.py
@@ -27,7 +27,7 @@ def is_volta_arch(arch: TileDevice) -> bool:
 
 def is_ampere_arch(arch: TileDevice) -> bool:
     conditions = [True]
     conditions.append(is_cuda_arch(arch))
-    conditions.append(arch.sm_version >= 80 and arch.sm_version < 90)
+    conditions.append(arch.sm_version >= 80 and arch.sm_version < 89)
     return all(conditions)
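The hunk above narrows `is_ampere_arch` so that sm_89 (Ada) no longer counts as Ampere; the dispatchers later in this diff import dedicated `is_ada_arch` / `is_hopper_arch` predicates instead. Their bodies are not part of this diff, so the following is only a sketch of how the sm ranges are assumed to be partitioned, written in the same conditions style:

```python
# Assumed counterparts to is_ampere_arch (not shown in this diff):
def is_ada_arch(arch: TileDevice) -> bool:
    # Ada Lovelace is exactly sm_89, the version carved out of the
    # Ampere range above.
    return is_cuda_arch(arch) and arch.sm_version == 89


def is_hopper_arch(arch: TileDevice) -> bool:
    # Hopper begins at sm_90, the old (exclusive) Ampere upper bound.
    return is_cuda_arch(arch) and arch.sm_version >= 90
```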
diff --git a/bitblas/builder/wrapper/tl.py b/bitblas/builder/wrapper/tl.py
index c64ca2b8d..5c635fed9 100644
--- a/bitblas/builder/wrapper/tl.py
+++ b/bitblas/builder/wrapper/tl.py
@@ -18,8 +18,8 @@ class TLCUDASourceWrapper(object):
 
     _TYPE_MAP = {
         "float32": "float",
-        "float16": "half",
-        "bfloat16": "__nv_bfloat16",
+        "float16": "half_t",
+        "bfloat16": "bfloat16_t",
         "e4m3_float8": "__nv_fp8_e4m3",
         "e5m2_float8": "__nv_fp8_e5m2",
         "float64": "double",
diff --git a/bitblas/gpu/matmul_analysis.py b/bitblas/gpu/matmul_analysis.py
index d43d95ffa..607f81fff 100644
--- a/bitblas/gpu/matmul_analysis.py
+++ b/bitblas/gpu/matmul_analysis.py
@@ -649,7 +649,9 @@ def check_last_trait(region: List[Range]):
     if target.kind.name == "cuda" and check_sm_version(target.arch) >= 70:
         in_dtype, out_dtype = get_in_out_dtypes(block_stmt)
         if not is_tensorcore_supported_precision(in_dtype, out_dtype, arch=get_arch(target)):
-            logger.debug("The input and output dtype is not supported by tensorcore")
+            logger.debug(
+                f"The input and output dtype ({in_dtype}, {out_dtype}) is not supported by tensorcore"
+            )
             return func, None
 
     # reindex and transform functions
diff --git a/bitblas/ops/general_matmul/tilelang/dense/__init__.py b/bitblas/ops/general_matmul/tilelang/dense/__init__.py
index 17843fc0f..888bbf2c9 100644
--- a/bitblas/ops/general_matmul/tilelang/dense/__init__.py
+++ b/bitblas/ops/general_matmul/tilelang/dense/__init__.py
@@ -102,7 +102,7 @@ def ampere_select_scheduler(
 
     trans_A, trans_B = parse_layout(layout)
 
-    def can_apply_fine_grain_scheduler(trans_A, trans_B, propagate_a, propagate_b):
+    def can_apply_mma_scheduler(trans_A, trans_B, propagate_a, propagate_b):
         conditions = []
         conditions.append(trans_A is False)
         conditions.append(trans_B is True)
@@ -116,7 +116,7 @@ def can_apply_block_scheduler(propagate_a, propagate_b):
         conditions.append(propagate_b == TransformKind.NonTransform)
         return all(conditions)
 
-    def can_apply_weight_propagation_scheduler(trans_A, trans_B, propagate_a, propagate_b):
+    def can_apply_mma_weight_propagation_scheduler(trans_A, trans_B, propagate_a, propagate_b):
         conditions = []
         conditions.append(trans_A is False)
         conditions.append(trans_B is True)
@@ -127,7 +127,7 @@ def can_apply_weight_propagation_scheduler(trans_A, trans_B, propagate_a, propag
     def is_int4_dtype(dtype):
         return dtype == "int4" or dtype == "uint4"
 
-    if can_apply_weight_propagation_scheduler(trans_A, trans_B, propagate_a, propagate_b):
+    if can_apply_mma_weight_propagation_scheduler(trans_A, trans_B, propagate_a, propagate_b):
         Scheduler = MatmulMMAWeightPropagationScheduler if not is_int4_dtype(
             in_dtype) else MatmulINT4MMAWeightPropagationScheduler
         return Scheduler(
@@ -141,7 +141,7 @@ def is_int4_dtype(dtype):
             accum_dtype=accum_dtype,
             with_bias=with_bias,
         )
-    if can_apply_fine_grain_scheduler(trans_A, trans_B, propagate_a, propagate_b):
+    if can_apply_mma_scheduler(trans_A, trans_B, propagate_a, propagate_b):
         Scheduler = MatmulMMAScheduler if not is_int4_dtype(in_dtype) else MatmulINT4MMAScheduler
         return Scheduler(
             M=M,
diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul.py b/bitblas/ops/general_matmul/tilelang/dense/matmul.py
index d18a77c8d..ef1bfed37 100644
--- a/bitblas/ops/general_matmul/tilelang/dense/matmul.py
+++ b/bitblas/ops/general_matmul/tilelang/dense/matmul.py
@@ -9,8 +9,11 @@
     TileDevice,
     is_ampere_arch,
     is_volta_arch,
+    is_ada_arch,
+    is_hopper_arch,
     is_tensorcore_supported_precision,
 )
+from tilelang.intrinsics.utils import get_mma_micro_size
 from dataclasses import dataclass
 from bitblas.tl.base_hint import BaseTLHint
@@ -39,20 +42,20 @@ class MatmulScheduler(MatmulBaseParams):
     gemv_scheduler: Optional[GemvFineGrainSIMTScheduler] = None
     matmul_simt_scheduler: Optional[MatmulFineGrainSIMTScheduler] = None
     matmul_block_scheduler: Optional[MatmulTileLibraryScheduler] = None
-    matmul_fine_grain_scheduler: Optional[MatmulMMAScheduler] = None
-    matmul_weight_propagation_scheduler: Optional[MatmulMMAWeightPropagationScheduler] = None
-    matmul_int4_fine_grain_scheduler: Optional[MatmulINT4MMAScheduler] = None
-    matmul_int4_weight_propagation_scheduler: Optional[
+    matmul_mma_scheduler: Optional[MatmulMMAScheduler] = None
+    matmul_mma_weight_propagation_scheduler: Optional[MatmulMMAWeightPropagationScheduler] = None
+    matmul_int4_mma_scheduler: Optional[MatmulINT4MMAScheduler] = None
+    matmul_int4_mma_weight_propagation_scheduler: Optional[
         MatmulINT4MMAWeightPropagationScheduler] = None
 
     def __init__(self, **kwargs):
         self.gemv_scheduler = GemvFineGrainSIMTScheduler(**kwargs)
         self.matmul_simt_scheduler = MatmulFineGrainSIMTScheduler(**kwargs)
         self.matmul_block_scheduler = MatmulTileLibraryScheduler(**kwargs)
-        self.matmul_fine_grain_scheduler = MatmulMMAScheduler(**kwargs)
-        self.matmul_weight_propagation_scheduler = MatmulMMAWeightPropagationScheduler(**kwargs)
-        self.matmul_int4_fine_grain_scheduler = MatmulINT4MMAScheduler(**kwargs)
-        self.matmul_int4_weight_propagation_scheduler = MatmulINT4MMAWeightPropagationScheduler(
+        self.matmul_mma_scheduler = MatmulMMAScheduler(**kwargs)
+        self.matmul_mma_weight_propagation_scheduler = MatmulMMAWeightPropagationScheduler(**kwargs)
+        self.matmul_int4_mma_scheduler = MatmulINT4MMAScheduler(**kwargs)
+        self.matmul_int4_mma_weight_propagation_scheduler = MatmulINT4MMAWeightPropagationScheduler(
             **kwargs)
 
         super().__init__(**kwargs)
@@ -72,14 +75,13 @@ def dispatch_ampere_scheduler(self, arch: TileDevice) -> BaseScheduler:
             if is_tensorcore_supported_precision(in_dtype, accum_dtype, arch):
                 if weight_transform_kind != TransformKind.NonTransform:
                     # INT4 Can be fused into general dequantize
-                    return self.matmul_int4_weight_propagation_scheduler if in_dtype == "int4" else self.matmul_weight_propagation_scheduler
-                return self.matmul_int4_fine_grain_scheduler if in_dtype == "int4" else self.matmul_fine_grain_scheduler
+                    return self.matmul_int4_mma_weight_propagation_scheduler if in_dtype == "int4" else self.matmul_mma_weight_propagation_scheduler
+                return self.matmul_int4_mma_scheduler if in_dtype == "int4" else self.matmul_mma_scheduler
             else:
                 return self.matmul_simt_scheduler
         else:
-            minimal_tensorcore_threshold: List[int, int,
-                                               int] = [8, 16, 32
-                                                      ] if accum_dtype == "int32" else [8, 16, 16]
+            _, _, micro_size_k = get_mma_micro_size(in_dtype)
+            minimal_tensorcore_threshold: List[int, int, int] = [8, 16, micro_size_k]
             if minimal_tensorcore_threshold[0] > M or minimal_tensorcore_threshold[
                     1] > N or minimal_tensorcore_threshold[2] > K:
                 if in_dtype == "int4":
@@ -90,10 +92,11 @@ def dispatch_ampere_scheduler(self, arch: TileDevice) -> BaseScheduler:
                 return self.gemv_scheduler
             elif is_tensorcore_supported_precision(in_dtype, accum_dtype, arch):
                 if self.weight_transform_kind != TransformKind.NonTransform:
-                    return (self.matmul_int4_weight_propagation_scheduler
-                            if in_dtype == "int4" else self.matmul_weight_propagation_scheduler)
+                    return (self.matmul_int4_mma_weight_propagation_scheduler
+                            if in_dtype == "int4" else self.matmul_mma_weight_propagation_scheduler)
                 else:
-                    return self.matmul_int4_fine_grain_scheduler if in_dtype == "int4" else self.matmul_block_scheduler
+                    # by default, use the mma_scheduler
+                    return self.matmul_int4_mma_scheduler if in_dtype == "int4" else self.matmul_mma_scheduler
             else:
                 return self.matmul_simt_scheduler
@@ -131,7 +134,10 @@ def dispatch_volta_scheduler(self, arch: TileDevice) -> BaseScheduler:
         return self.matmul_simt_scheduler
 
     def dispatch_scheduler(self, arch: TileDevice) -> BaseScheduler:
-        if is_ampere_arch(arch):
+        if is_hopper_arch(arch):
+            logger.warning("Hopper architecture is not fully supported yet; falling back to the Ampere/Ada schedules")
+            return self.dispatch_ampere_scheduler(arch)
+        elif is_ampere_arch(arch) or is_ada_arch(arch):
             return self.dispatch_ampere_scheduler(arch)
         elif is_volta_arch(arch):
             return self.dispatch_volta_scheduler(arch)
@@ -143,10 +149,10 @@ def detect_scheduler_from_hint(self, hint: BaseTLHint) -> BaseScheduler:
             self.gemv_scheduler,
             self.matmul_simt_scheduler,
             self.matmul_block_scheduler,
-            self.matmul_fine_grain_scheduler,
-            self.matmul_weight_propagation_scheduler,
-            self.matmul_int4_fine_grain_scheduler,
-            self.matmul_int4_weight_propagation_scheduler,
+            self.matmul_mma_scheduler,
+            self.matmul_mma_weight_propagation_scheduler,
+            self.matmul_int4_mma_scheduler,
+            self.matmul_int4_mma_weight_propagation_scheduler,
         ]:
             try:
                 scheduler_hint_type = scheduler.get_hint_type()
@@ -213,10 +219,10 @@ def set_dynamic_range(self, dynamic_range: Dict[str, int]) -> "BaseScheduler":
             self.gemv_scheduler,
             self.matmul_simt_scheduler,
             self.matmul_block_scheduler,
-            self.matmul_fine_grain_scheduler,
-            self.matmul_weight_propagation_scheduler,
-            self.matmul_int4_fine_grain_scheduler,
-            self.matmul_int4_weight_propagation_scheduler,
+            self.matmul_mma_scheduler,
+            self.matmul_mma_weight_propagation_scheduler,
+            self.matmul_int4_mma_scheduler,
+            self.matmul_int4_mma_weight_propagation_scheduler,
         ]:
             scheduler.set_dynamic_range(dynamic_range)
         return self
@@ -227,10 +233,10 @@ def with_arch(self, arch):
             self.gemv_scheduler,
             self.matmul_simt_scheduler,
             self.matmul_block_scheduler,
-            self.matmul_fine_grain_scheduler,
-            self.matmul_weight_propagation_scheduler,
-            self.matmul_int4_fine_grain_scheduler,
-            self.matmul_int4_weight_propagation_scheduler,
+            self.matmul_mma_scheduler,
+            self.matmul_mma_weight_propagation_scheduler,
+            self.matmul_int4_mma_scheduler,
+            self.matmul_int4_mma_weight_propagation_scheduler,
         ]:
             scheduler.with_arch(arch)
         return self
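The `@@ -72,14` hunk replaces the hard-coded `[8, 16, 32] if accum_dtype == "int32" else [8, 16, 16]` threshold with the k-extent of the MMA micro shape (the dequantize dispatcher below receives the same change). A minimal sketch of the new derivation, assuming the usual m16n8k16 (fp16) and m16n8k32 (int8) Tensor Core shapes:

```python
from tilelang.intrinsics.utils import get_mma_micro_size

def minimal_tensorcore_threshold(in_dtype: str) -> list:
    # get_mma_micro_size returns the (m, n, k) extents of a single MMA
    # instruction; only k varies with the input dtype (e.g. 16 for
    # float16, 32 for int8), which is what the old accum_dtype branch
    # approximated indirectly.
    _, _, micro_size_k = get_mma_micro_size(in_dtype)
    return [8, 16, micro_size_k]
```

Problems smaller than this threshold in any dimension still fall through to the GEMV or SIMT schedulers, exactly as before.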
diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul_mma.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_mma.py
index f36c05663..531035647 100644
--- a/bitblas/ops/general_matmul/tilelang/dense/matmul_mma.py
+++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_mma.py
@@ -7,7 +7,7 @@
 from tvm import DataType
 import tilelang.language as T
 from typing import Optional, List
-from bitblas.tl.utils import (
+from tilelang.intrinsics.utils import (
     get_mma_micro_size,
     make_mma_swizzle_layout as make_swizzle_layout,
 )
diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul_tile.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_tile.py
index a906bc308..763184e90 100644
--- a/bitblas/ops/general_matmul/tilelang/dense/matmul_tile.py
+++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_tile.py
@@ -6,7 +6,7 @@
 from bitblas import tilelang as tilelang
 import tilelang.language as T
 from typing import Optional, List
-from bitblas.tl.utils import (
+from tilelang.intrinsics.utils import (
     get_mma_micro_size,
     make_mma_swizzle_layout as make_swizzle_layout,
 )
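The two hunks above (and their dequantize counterparts further down, where the now-used imports also drop their `# noqa: F401` markers) retire the local `bitblas.tl.utils` re-exports in favor of the upstream helpers made available by the tilelang submodule bump at the top of this diff. For reference, the new import surface is:

```python
# Canonical upstream location after the tilelang bump (taken verbatim
# from the hunks in this diff):
from tilelang.intrinsics.utils import (
    get_mma_micro_size,
    make_mma_swizzle_layout as make_swizzle_layout,
    index_to_coordinates,
)
```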
diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize.py
index 503453760..562714fe2 100644
--- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize.py
+++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize.py
@@ -9,8 +9,11 @@
     TileDevice,
     is_ampere_arch,
     is_volta_arch,
+    is_ada_arch,
+    is_hopper_arch,
     is_tensorcore_supported_precision,
 )
+from tilelang.intrinsics.utils import get_mma_micro_size
 from dataclasses import dataclass
 from bitblas.tl.base_hint import BaseTLHint
@@ -39,23 +42,22 @@ class MatmulDequantizeScheduler(MatmulDequantizeBaseParams):
     gemv_dequantize_simt_scheduler: Optional[GemvDequantizeSIMTScheduler] = None
     matmul_dequantize_simt_scheduler: Optional[MatmulDequantizeSIMTScheduler] = None
     matmul_dequantize_block_scheduler: Optional[MatmulDequantizeTileLibraryScheduler] = None
-    matmul_dequantize_fine_grained_scheduler: Optional[MatmulDequantizeMMAScheduler] = None
-    matmul_dequantize_weight_propagation_scheduler: Optional[
+    matmul_dequantize_mma_scheduler: Optional[MatmulDequantizeMMAScheduler] = None
+    matmul_dequantize_mma_weight_propagation_scheduler: Optional[
         MatmulDequantizeMMAWeightPropagationScheduler] = None
-    matmul_int4_dequantize_fine_grain_scheduler: Optional[MatmulINT4DequantizeMMAScheduler] = None
-    matmul_int4_dequantize_weight_propagation_scheduler: Optional[
+    matmul_int4_dequantize_mma_scheduler: Optional[MatmulINT4DequantizeMMAScheduler] = None
+    matmul_int4_dequantize_mma_weight_propagation_scheduler: Optional[
         MatmulINT4DequantizeMMAWeightPropagationScheduler] = None
 
     def __init__(self, **kwargs):
         self.gemv_dequantize_simt_scheduler = GemvDequantizeSIMTScheduler(**kwargs)
         self.matmul_dequantize_simt_scheduler = MatmulDequantizeSIMTScheduler(**kwargs)
         self.matmul_dequantize_block_scheduler = MatmulDequantizeTileLibraryScheduler(**kwargs)
-        self.matmul_dequantize_fine_grained_scheduler = MatmulDequantizeMMAScheduler(**kwargs)
-        self.matmul_dequantize_weight_propagation_scheduler = MatmulDequantizeMMAWeightPropagationScheduler(
+        self.matmul_dequantize_mma_scheduler = MatmulDequantizeMMAScheduler(**kwargs)
+        self.matmul_dequantize_mma_weight_propagation_scheduler = MatmulDequantizeMMAWeightPropagationScheduler(
             **kwargs)
-        self.matmul_int4_dequantize_fine_grain_scheduler = MatmulINT4DequantizeMMAScheduler(
-            **kwargs)
-        self.matmul_int4_dequantize_weight_propagation_scheduler = MatmulINT4DequantizeMMAWeightPropagationScheduler(
+        self.matmul_int4_dequantize_mma_scheduler = MatmulINT4DequantizeMMAScheduler(**kwargs)
+        self.matmul_int4_dequantize_mma_weight_propagation_scheduler = MatmulINT4DequantizeMMAWeightPropagationScheduler(
             **kwargs)
 
         super().__init__(**kwargs)
@@ -76,10 +78,10 @@ def dispatch_ampere_scheduler(self, arch: TileDevice) -> BaseScheduler:
             if is_tensorcore_supported_precision(in_dtype, accum_dtype, arch):
                 if weight_transform_kind != TransformKind.NonTransform:
                     # INT4 Can be fused into general dequantize
-                    return (self.matmul_int4_dequantize_weight_propagation_scheduler if in_dtype
-                            == "int4" else self.matmul_dequantize_weight_propagation_scheduler)
+                    return (self.matmul_int4_dequantize_mma_weight_propagation_scheduler if in_dtype
+                            == "int4" else self.matmul_dequantize_mma_weight_propagation_scheduler)
                 else:
-                    return self.matmul_int4_dequantize_fine_grain_scheduler if in_dtype == "int4" else self.matmul_dequantize_fine_grained_scheduler
+                    return self.matmul_int4_dequantize_mma_scheduler if in_dtype == "int4" else self.matmul_dequantize_mma_scheduler
             else:
                 if in_dtype == "int4":
                     raise ValueError("INT4 is not supported for non-TensorCore architectures")
@@ -88,8 +90,8 @@ def dispatch_ampere_scheduler(self, arch: TileDevice) -> BaseScheduler:
                     "Weight propagation is not supported for non-TensorCore architectures")
                 return self.matmul_dequantize_simt_scheduler
         else:
-            minimal_tensorcore_threshold: List[int, int, int] = ([8, 16, 32] if accum_dtype
-                                                                 == "int32" else [8, 16, 16])
+            _, _, micro_size_k = get_mma_micro_size(in_dtype)
+            minimal_tensorcore_threshold: List[int, int, int] = [8, 16, micro_size_k]
             if (minimal_tensorcore_threshold[0] > M or minimal_tensorcore_threshold[1] > N or
                     minimal_tensorcore_threshold[2] > K):
                 if in_dtype == "int4":
@@ -101,10 +103,10 @@ def dispatch_ampere_scheduler(self, arch: TileDevice) -> BaseScheduler:
         elif is_tensorcore_supported_precision(in_dtype, accum_dtype, arch):
             if self.weight_transform_kind != TransformKind.NonTransform:
                 return (
-                    self.matmul_int4_dequantize_weight_propagation_scheduler
-                ) if in_dtype == "int4" else self.matmul_dequantize_weight_propagation_scheduler
+                    self.matmul_int4_dequantize_mma_weight_propagation_scheduler
+                ) if in_dtype == "int4" else self.matmul_dequantize_mma_weight_propagation_scheduler
             else:
-                return self.matmul_int4_dequantize_fine_grain_scheduler if in_dtype == "int4" else self.matmul_dequantize_fine_grained_scheduler
+                return self.matmul_int4_dequantize_mma_scheduler if in_dtype == "int4" else self.matmul_dequantize_mma_scheduler
         else:
             return self.matmul_dequantize_simt_scheduler
@@ -142,7 +144,10 @@ def dispatch_volta_scheduler(self, arch: TileDevice) -> BaseScheduler:
         return self.matmul_dequantize_simt_scheduler
 
     def dispatch_scheduler(self, arch: TileDevice) -> BaseScheduler:
-        if is_ampere_arch(arch):
+        if is_hopper_arch(arch):
+            logger.warning("Hopper architecture is not fully supported for dequantize yet; falling back to the Ampere/Ada schedules")
+            return self.dispatch_ampere_scheduler(arch)
+        elif is_ampere_arch(arch) or is_ada_arch(arch):
             return self.dispatch_ampere_scheduler(arch)
         elif is_volta_arch(arch):
             return self.dispatch_volta_scheduler(arch)
@@ -154,10 +159,10 @@ def detect_scheduler_from_hint(self, hint: BaseTLHint) -> BaseScheduler:
             self.gemv_dequantize_simt_scheduler,
             self.matmul_dequantize_simt_scheduler,
             self.matmul_dequantize_block_scheduler,
-            self.matmul_dequantize_fine_grained_scheduler,
-            self.matmul_dequantize_weight_propagation_scheduler,
-            self.matmul_int4_dequantize_fine_grain_scheduler,
-            self.matmul_int4_dequantize_weight_propagation_scheduler,
+            self.matmul_dequantize_mma_scheduler,
+            self.matmul_dequantize_mma_weight_propagation_scheduler,
+            self.matmul_int4_dequantize_mma_scheduler,
+            self.matmul_int4_dequantize_mma_weight_propagation_scheduler,
         ]:
             try:
                 scheduler_hint_type = scheduler.get_hint_type()
@@ -224,10 +229,10 @@ def set_dynamic_range(self, dynamic_range: Dict[str, int]) -> "BaseScheduler":
             self.gemv_dequantize_simt_scheduler,
             self.matmul_dequantize_simt_scheduler,
             self.matmul_dequantize_block_scheduler,
-            self.matmul_dequantize_fine_grained_scheduler,
-            self.matmul_dequantize_weight_propagation_scheduler,
-            self.matmul_int4_dequantize_fine_grain_scheduler,
-            self.matmul_int4_dequantize_weight_propagation_scheduler,
+            self.matmul_dequantize_mma_scheduler,
+            self.matmul_dequantize_mma_weight_propagation_scheduler,
+            self.matmul_int4_dequantize_mma_scheduler,
+            self.matmul_int4_dequantize_mma_weight_propagation_scheduler,
         ]:
             scheduler.set_dynamic_range(dynamic_range)
         return self
@@ -238,10 +243,10 @@ def with_arch(self, arch):
             self.gemv_dequantize_simt_scheduler,
             self.matmul_dequantize_simt_scheduler,
             self.matmul_dequantize_block_scheduler,
-            self.matmul_dequantize_fine_grained_scheduler,
-            self.matmul_dequantize_weight_propagation_scheduler,
-            self.matmul_int4_dequantize_fine_grain_scheduler,
-            self.matmul_int4_dequantize_weight_propagation_scheduler,
+            self.matmul_dequantize_mma_scheduler,
+            self.matmul_dequantize_mma_weight_propagation_scheduler,
+            self.matmul_int4_dequantize_mma_scheduler,
+            self.matmul_int4_dequantize_mma_weight_propagation_scheduler,
         ]:
             scheduler.with_arch(arch)
         return self
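Both the dense and dequantize dispatchers now route Hopper through the Ampere path instead of rejecting it. A condensed sketch of the resulting cascade (the trailing error branch is assumed; it lies outside the hunks shown):

```python
# Condensed from the dispatch_scheduler hunks above; the final
# `raise` is an assumption about the unchanged tail of the method.
def dispatch_scheduler(self, arch: TileDevice) -> BaseScheduler:
    if is_hopper_arch(arch):
        # No Hopper-specific schedules yet; the Ampere-era MMA
        # schedules still execute on sm_90, so reuse that path.
        return self.dispatch_ampere_scheduler(arch)
    elif is_ampere_arch(arch) or is_ada_arch(arch):
        return self.dispatch_ampere_scheduler(arch)
    elif is_volta_arch(arch):
        return self.dispatch_volta_scheduler(arch)
    else:
        raise ValueError(f"Unsupported architecture: {arch}")
```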
diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_mma.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_mma.py
index aea3d331e..eb02b0a30 100644
--- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_mma.py
+++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_mma.py
@@ -5,10 +5,10 @@
 from bitblas import tilelang as tilelang
 import tilelang.language as T
 from typing import Optional, List
-from bitblas.tl.utils import (
-    get_mma_micro_size,  # noqa: F401
-    make_mma_swizzle_layout as make_swizzle_layout,  # noqa: F401
-    index_to_coordinates,  # noqa: F401
+from tilelang.intrinsics.utils import (
+    get_mma_micro_size,
+    make_mma_swizzle_layout as make_swizzle_layout,
+    index_to_coordinates,
 )
 from bitblas.ops.general_matmul.tirscript import (
     matmul_dequantize_select_implementation,)
diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_mma_weight_transform.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_mma_weight_transform.py
index 67330730d..b65465902 100644
--- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_mma_weight_transform.py
+++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_mma_weight_transform.py
@@ -5,10 +5,10 @@
 from tvm import DataType
 import tilelang.language as T
 from typing import Optional, List
-from bitblas.tl.utils import (
-    get_mma_micro_size,  # noqa: F401
-    make_mma_swizzle_layout as make_swizzle_layout,  # noqa: F401
-    index_to_coordinates,  # noqa: F401
+from tilelang.intrinsics.utils import (
+    get_mma_micro_size,
+    make_mma_swizzle_layout as make_swizzle_layout,
+    index_to_coordinates,
 )
 from bitblas.base.arch import TileDevice
 from bitblas.base.roller.hint import Hint
diff --git a/bitblas/tl/tuner.py b/bitblas/tl/tuner.py
index 0c6eded9c..539838393 100644
--- a/bitblas/tl/tuner.py
+++ b/bitblas/tl/tuner.py
@@ -140,11 +140,14 @@ def tvm_callback_cuda_postproc(code, _):
 
         for future in as_completed(future_to_idx, timeout=timeout):
             idx = future_to_idx[future]
+            assert idx < len(_scheduled_ir_modules), "Index out of range"
+            assert idx < len(configs), "Index out of range"
+
+            ir_module = _scheduled_ir_modules[idx]
+            config = configs[idx]
             try:
                 idx, code, artifact_path = future.result()
-                ir_module = _scheduled_ir_modules[idx]
                 sch = tvm.tir.Schedule(ir_module)
-                config = configs[idx]
 
                 if artifact_path is None:
                     ARTIFACT_NOT_FOUND = f"Apply config {config} failed, artifact path is None"
@@ -165,7 +168,7 @@ def tvm_callback_cuda_postproc(code, _):
                 local_build_error = (
                     local_build_error[:MAX_ERROR_MESSAGE_LENGTH] + "\t...\t" +
                     local_build_error[-MAX_ERROR_MESSAGE_LENGTH:])
-                logger.error(f"An exception occurred for index {idx}: {local_build_error}")
+                logger.error(f"An exception occurred for hint {config}: {local_build_error}")
 
         best = None
         best_latency = 1e9
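The tuner hunk hoists the `ir_module`/`config` lookups out of the `try` block, after first bounds-checking the index (strict `<`, since `idx` indexes into both lists). The payoff is in the `except` path: the error can now name the offending config even when `future.result()` itself raises. A self-contained sketch of the pattern, with hypothetical `tasks`/`compile_one` stand-ins for the scheduled modules and build step:

```python
from concurrent.futures import ThreadPoolExecutor, as_completed

def compile_all(tasks, compile_one, timeout=30):
    """Sketch of the tuner's error-attribution pattern (simplified)."""
    results, errors = [], []
    with ThreadPoolExecutor() as pool:
        future_to_idx = {pool.submit(compile_one, t): i for i, t in enumerate(tasks)}
        for future in as_completed(future_to_idx, timeout=timeout):
            idx = future_to_idx[future]
            assert idx < len(tasks), "Index out of range"
            task = tasks[idx]  # resolved before .result() so the
            try:               # except branch can still report it
                results.append((task, future.result()))
            except Exception as exc:
                # Report against the offending task, not just its index.
                errors.append((task, exc))
    return results, errors
```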
diff --git a/testing/python/operators/test_general_matmul_fp8.py b/testing/python/operators/test_general_matmul_fp8.py
index 0cdf2d545..50e58b409 100644
--- a/testing/python/operators/test_general_matmul_fp8.py
+++ b/testing/python/operators/test_general_matmul_fp8.py
@@ -27,7 +27,7 @@ def matmul_torch_forward(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layo
         with_zeros=with_zeros,
         zeros_mode=zeros_mode,
     )
-    matmul = Matmul(config=matmul_config, enable_tuning=True, backend="tl")
+    matmul = Matmul(config=matmul_config, enable_tuning=False, backend="tl")
 
     input_shape = (M, K)
     weight_shape = (N, K) if layout == "nt" else (K, N)
@@ -93,7 +93,7 @@ def matmul_torch_forward_weight_dequantize(M, N, K, A_dtype, W_dtype, accum_dtyp
         propagate_a=False,
         propagate_b=False,
     )
-    matmul = Matmul(config=matmul_config, enable_tuning=False, backend="tl")
+    matmul = Matmul(config=matmul_config, enable_tuning=False, backend="tir")
 
     input_shape = (M, K)
     weight_shape = (N, K) if layout == "nt" else (K, N)
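The two test hunks pin the FP8 paths to deterministic configurations: the GEMM test keeps the TileLang backend but disables tuning, and the weight-dequantize test switches to the TIR backend. A hypothetical invocation mirroring the first test (shapes and dtypes illustrative; the config fields are taken from the test's signature):

```python
from bitblas import Matmul, MatmulConfig

config = MatmulConfig(
    M=256, N=256, K=256,
    A_dtype="e4m3_float8",
    W_dtype="e4m3_float8",
    accum_dtype="float32",
    out_dtype="float16",
    layout="nt",
)
# FP8 GEMM on the TileLang backend with tuning disabled, as in the test.
matmul = Matmul(config=config, enable_tuning=False, backend="tl")
```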