From a1e7c81e63cf90e26f7d7653bec1c249dd578945 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 2 Feb 2025 10:21:47 +0000 Subject: [PATCH 01/28] Refactor: Rename tensor core related classes and imports to use tile terminology --- .../general_matmul/tilelang/dense/__init__.py | 23 +- .../general_matmul/tilelang/dense/matmul.py | 34 +- .../{matmul_tensorcore.py => matmul_mma.py} | 587 +---------------- .../tilelang/dense/matmul_tile.py | 589 ++++++++++++++++++ .../tilelang/dense/matmul_wmma.py | 6 +- .../tilelang/dequantize/__init__.py | 16 +- .../tilelang/dequantize/matmul_dequantize.py | 36 +- ...inegrained.py => matmul_dequantize_mma.py} | 12 +- ...matmul_dequantize_mma_weight_transform.py} | 14 +- ...ensorcore.py => matmul_dequantize_tile.py} | 6 +- .../test_general_matmul_tilelang_impl.py | 2 +- .../test_general_matmul_tilelang_kernel.py | 54 +- .../test_general_matmul_tilelang_scheduler.py | 8 +- 13 files changed, 711 insertions(+), 676 deletions(-) rename bitblas/ops/general_matmul/tilelang/dense/{matmul_tensorcore.py => matmul_mma.py} (66%) create mode 100644 bitblas/ops/general_matmul/tilelang/dense/matmul_tile.py rename bitblas/ops/general_matmul/tilelang/dequantize/{matmul_dequantize_tensorcore_finegrained.py => matmul_dequantize_mma.py} (98%) rename bitblas/ops/general_matmul/tilelang/dequantize/{matmul_dequantize_tensorcore_weight_transform.py => matmul_dequantize_mma_weight_transform.py} (98%) rename bitblas/ops/general_matmul/tilelang/dequantize/{matmul_dequantize_tensorcore.py => matmul_dequantize_tile.py} (99%) diff --git a/bitblas/ops/general_matmul/tilelang/dense/__init__.py b/bitblas/ops/general_matmul/tilelang/dense/__init__.py index 8ae3bc500..17843fc0f 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/__init__.py +++ b/bitblas/ops/general_matmul/tilelang/dense/__init__.py @@ -5,12 +5,14 @@ MatmulFineGrainSIMTScheduler, # noqa: F401 ) -from .matmul_tensorcore import ( - MatmulBlockScheduler, - MatmulFineGrainScheduler, - MatmulWeightPropagationScheduler, - MatmulINT4FineGrainScheduler, - MatmulINT4WeightPropagationScheduler, +from .matmul_tile import ( + MatmulTileLibraryScheduler,) + +from .matmul_mma import ( + MatmulMMAScheduler, + MatmulMMAWeightPropagationScheduler, + MatmulINT4MMAScheduler, + MatmulINT4MMAWeightPropagationScheduler, ) from .matmul import MatmulScheduler @@ -126,8 +128,8 @@ def is_int4_dtype(dtype): return dtype == "int4" or dtype == "uint4" if can_apply_weight_propagation_scheduler(trans_A, trans_B, propagate_a, propagate_b): - Scheduler = MatmulWeightPropagationScheduler if not is_int4_dtype( - in_dtype) else MatmulINT4WeightPropagationScheduler + Scheduler = MatmulMMAWeightPropagationScheduler if not is_int4_dtype( + in_dtype) else MatmulINT4MMAWeightPropagationScheduler return Scheduler( M=M, N=N, @@ -140,8 +142,7 @@ def is_int4_dtype(dtype): with_bias=with_bias, ) if can_apply_fine_grain_scheduler(trans_A, trans_B, propagate_a, propagate_b): - Scheduler = MatmulFineGrainScheduler if not is_int4_dtype( - in_dtype) else MatmulINT4FineGrainScheduler + Scheduler = MatmulMMAScheduler if not is_int4_dtype(in_dtype) else MatmulINT4MMAScheduler return Scheduler( M=M, N=N, @@ -154,7 +155,7 @@ def is_int4_dtype(dtype): with_bias=with_bias, ) elif can_apply_block_scheduler(propagate_a, propagate_b): - return MatmulBlockScheduler( + return MatmulTileLibraryScheduler( M=M, N=N, K=K, diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul.py b/bitblas/ops/general_matmul/tilelang/dense/matmul.py index 8b197e569..d18a77c8d 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/matmul.py +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul.py @@ -17,12 +17,13 @@ from .base import MatmulBaseParams from .gemv_simt import GemvFineGrainSIMTScheduler from .matmul_simt import MatmulFineGrainSIMTScheduler -from .matmul_tensorcore import ( - MatmulBlockScheduler, - MatmulFineGrainScheduler, - MatmulWeightPropagationScheduler, - MatmulINT4FineGrainScheduler, - MatmulINT4WeightPropagationScheduler, +from .matmul_tile import ( + MatmulTileLibraryScheduler,) +from .matmul_mma import ( + MatmulMMAScheduler, + MatmulMMAWeightPropagationScheduler, + MatmulINT4MMAScheduler, + MatmulINT4MMAWeightPropagationScheduler, ) import logging @@ -37,20 +38,21 @@ class MatmulScheduler(MatmulBaseParams): gemv_scheduler: Optional[GemvFineGrainSIMTScheduler] = None matmul_simt_scheduler: Optional[MatmulFineGrainSIMTScheduler] = None - matmul_block_scheduler: Optional[MatmulBlockScheduler] = None - matmul_fine_grain_scheduler: Optional[MatmulFineGrainScheduler] = None - matmul_weight_propagation_scheduler: Optional[MatmulWeightPropagationScheduler] = None - matmul_int4_fine_grain_scheduler: Optional[MatmulINT4FineGrainScheduler] = None - matmul_int4_weight_propagation_scheduler: Optional[MatmulINT4WeightPropagationScheduler] = None + matmul_block_scheduler: Optional[MatmulTileLibraryScheduler] = None + matmul_fine_grain_scheduler: Optional[MatmulMMAScheduler] = None + matmul_weight_propagation_scheduler: Optional[MatmulMMAWeightPropagationScheduler] = None + matmul_int4_fine_grain_scheduler: Optional[MatmulINT4MMAScheduler] = None + matmul_int4_weight_propagation_scheduler: Optional[ + MatmulINT4MMAWeightPropagationScheduler] = None def __init__(self, **kwargs): self.gemv_scheduler = GemvFineGrainSIMTScheduler(**kwargs) self.matmul_simt_scheduler = MatmulFineGrainSIMTScheduler(**kwargs) - self.matmul_block_scheduler = MatmulBlockScheduler(**kwargs) - self.matmul_fine_grain_scheduler = MatmulFineGrainScheduler(**kwargs) - self.matmul_weight_propagation_scheduler = MatmulWeightPropagationScheduler(**kwargs) - self.matmul_int4_fine_grain_scheduler = MatmulINT4FineGrainScheduler(**kwargs) - self.matmul_int4_weight_propagation_scheduler = MatmulINT4WeightPropagationScheduler( + self.matmul_block_scheduler = MatmulTileLibraryScheduler(**kwargs) + self.matmul_fine_grain_scheduler = MatmulMMAScheduler(**kwargs) + self.matmul_weight_propagation_scheduler = MatmulMMAWeightPropagationScheduler(**kwargs) + self.matmul_int4_fine_grain_scheduler = MatmulINT4MMAScheduler(**kwargs) + self.matmul_int4_weight_propagation_scheduler = MatmulINT4MMAWeightPropagationScheduler( **kwargs) super().__init__(**kwargs) diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_mma.py similarity index 66% rename from bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py rename to bitblas/ops/general_matmul/tilelang/dense/matmul_mma.py index f2bb5bd4d..c28ebee9e 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_mma.py @@ -1,5 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +# tile represents tile library + from bitblas import tvm as tvm from tvm import DataType import tvm.tl.language as T @@ -23,248 +25,14 @@ from dataclasses import dataclass from bitblas.ops.general_matmul.tirscript import (matmul_select_implementation) from bitblas.tl.base_hint import BaseTLHint -from .base import MatmulBaseParams +from .matmul_tile import MatmulBaseScheduler + # GPU warp configuration for NVIDIA GPUs warp_size = 32 @dataclass -class MatmulBaseScheduler(MatmulBaseParams): - - def get_roller_configs(self, arch: TileDevice = None, topk: int = 10): - layout = f"{'t' if self.trans_A else 'n'}{'t' if self.trans_B else 'n'}" - - # Simple TIR Compute Expression - ir_module = matmul_select_implementation( - M=self.M, - N=self.N, - K=self.K, - in_dtype=self.in_dtype, - out_dtype=self.out_dtype, - accum_dtype=self.accum_dtype, - layout=layout, - ) - - roller_hints = get_roller_hints_from_func( - ir_module, - arch, - topk, - tensorcore_only=True, - allow_gemv=True, - ) - - if roller_hints is None: - raise ValueError("No Roller Hints Found for TensorCore Scheduling") - - return self.serialize_hints_to_configs(roller_hints) - - def get_hardware_aware_configs(self, arch: TileDevice = None, topk=10): - if arch is None: - arch = self.arch - return self.get_roller_configs(arch, topk) - - # check if required shared memory cache - def check_require_cache(self) -> bool: - with_bias = self.with_bias - - conditions: List[bool] = [] - conditions.append(False) - # Bias Add should be performed in shared memory - conditions.append(with_bias) - return any(conditions) # Always set to False Currently - - -@dataclass -class MatmulBlockScheduler(MatmulBaseScheduler): - - # Default Tile Related Params - block_M: int = 64 - block_N: int = 64 - block_K: int = 32 - num_stages: int = 2 - threads: int = 128 - enable_rasterization: bool = False # Enhance L2 Locality - - class TLHint(BaseTLHint): - - hint_type = "MatmulBlockScheduler" - - def __init__(self): - super().__init__() - - @classmethod - def from_roller_hint(cls, hint: Hint): - tl_hint = cls() - for key, value in hint.__dict__.items(): - setattr(tl_hint, key, value) - - block = hint.block - warp = hint.warp - rstep = hint.rstep - num_stages = hint.pipeline_stage - rasterization_plan = hint.rasterization_plan - enable_rasterization = not isinstance(rasterization_plan, NoRasterization) - - block_row_warps = block[0] // warp[0] - block_col_warps = block[1] // warp[1] - warp_size = 32 # NVIDIA GPU warp size is 32 - if num_stages == 1: - num_stages = 0 # disable pipelining - - tl_hint.block_M = block[0] - tl_hint.block_N = block[1] - tl_hint.block_K = rstep[0] - tl_hint.num_stages = num_stages - tl_hint.threads = warp_size * block_row_warps * block_col_warps - tl_hint.enable_rasterization = enable_rasterization - - return tl_hint - - def get_config_params(self): - return { - "block_M": self.block_M, - "block_N": self.block_N, - "block_K": self.block_K, - "num_stages": self.num_stages, - "threads": self.threads, - "enable_rasterization": self.enable_rasterization, - } - - def __repr__(self): - return ("{" - f"block_M={self.block_M}," - f"block_N={self.block_N}," - f"block_K={self.block_K}," - f"num_stages={self.num_stages}," - f"threads={self.threads}," - f"enable_rasterization={self.enable_rasterization}" - "}") - - def get_configs_sm80(self): - num_stages = 2 - configs = [ - { - 'block_M': 128, - 'block_N': 256, - 'block_K': 32, - 'threads': 128 - }, - { - 'block_M': 256, - 'block_N': 128, - 'block_K': 32, - 'threads': 128 - }, - { - 'block_M': 128, - 'block_N': 128, - 'block_K': 32, - 'threads': 128 - }, - ] - configs = [{**c, 'num_stages': num_stages} for c in configs] - return configs - - def get_hint_type(self): - return self.TLHint.hint_type - - def serialize_hints_to_configs(self, hints: List[Hint]): - configs = [] - for hint in hints: - config = self.TLHint.from_roller_hint(hint) - configs.append(config) - return configs - - def with_default_config(self): - block_M = getattr(self, "block_M", 64) - block_N = getattr(self, "block_N", 64) - block_K = getattr(self, "block_K", 32) - num_stages = getattr(self, "num_stages", 2) - threads = getattr(self, "threads", 128) - enable_rasterization = getattr(self, "enable_rasterization", False) - - return self.apply_config( - block_M=block_M, - block_N=block_N, - block_K=block_K, - num_stages=num_stages, - threads=threads, - enable_rasterization=enable_rasterization, - ) - - def apply_config( - self, - block_M: Optional[int] = None, - block_N: Optional[int] = None, - block_K: Optional[int] = None, - num_stages: Optional[int] = None, - threads: Optional[int] = None, - # Enhance L2 Locality - enable_rasterization: bool = False, - ): - assert block_M is not None, "block_M is required" - assert block_N is not None, "block_N is required" - assert block_K is not None, "block_K is required" - assert num_stages is not None, "num_stages is required" - assert threads is not None, "threads is required" - - M = self.maybe_dynamic(self.M, "m") - N, K = self.N, self.K - assert isinstance(N, int) and isinstance(K, int), "Do not support dynamic N and K Currently" - - trans_A, trans_B = self.trans_A, self.trans_B - in_dtype, out_dtype, accum_dtype = self.in_dtype, self.out_dtype, self.accum_dtype - with_bias = self.with_bias - - A_shape = (K, M) if trans_A else (M, K) - B_shape = (N, K) if trans_B else (K, N) - C_shape = (M, N) - Bias_shape = (N,) - A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) - B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) - - @T.prim_func - def main( - A: T.Buffer(A_shape, in_dtype), - B: T.Buffer(B_shape, in_dtype), - Bias: T.Buffer(Bias_shape, out_dtype), - C: T.Buffer(C_shape, out_dtype), - ): - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype) - B_shared = T.alloc_shared(B_shared_shape, in_dtype) - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - - T.use_swizzle(10, enable=enable_rasterization) - - T.clear(C_local) - for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - if trans_A: - T.copy(A[k * block_K, by * block_M], A_shared) - else: - T.copy(A[by * block_M, k * block_K], A_shared) - if trans_B: - T.copy(B[bx * block_N, k * block_K], B_shared) - else: - T.copy(B[k * block_K, bx * block_N], B_shared) - T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) - - if with_bias: - for i, j in T.Parallel(block_M, block_N): - C_local[i, j] += Bias[bx * block_N + j] - - T.copy(C_local, C[by * block_M, bx * block_N]) - - return self.post_process(main) - - def __post_init__(self): - # Add Config Validation - return - - -@dataclass -class MatmulFineGrainScheduler(MatmulBaseScheduler): +class MatmulMMAScheduler(MatmulBaseScheduler): # Fine-grained matrix multiplication scheduler # Allows for more detailed configuration. @@ -281,7 +49,7 @@ class MatmulFineGrainScheduler(MatmulBaseScheduler): class TLHint(BaseTLHint): - hint_type: str = "MatmulFineGrainScheduler" + hint_type: str = "MatmulMMAScheduler" def __init__(self): super().__init__() @@ -561,13 +329,13 @@ def __post_init__(self): @dataclass -class MatmulWeightPropagationScheduler(MatmulFineGrainScheduler): +class MatmulMMAWeightPropagationScheduler(MatmulMMAScheduler): # force set default weight transform kind to LDMatrixTransform weight_transform_kind: TransformKind = TransformKind.LDMatrixTransform - class TLHint(MatmulFineGrainScheduler.TLHint): - hint_type: str = "MatmulWeightPropagationScheduler" + class TLHint(MatmulMMAScheduler.TLHint): + hint_type: str = "MatmulMMAWeightPropagationScheduler" def apply_config( self, @@ -794,11 +562,11 @@ def is_b_smooth(self): @dataclass -class MatmulINT4FineGrainScheduler(MatmulFineGrainScheduler): +class MatmulINT4MMAScheduler(MatmulMMAScheduler): @dataclass - class TLHint(MatmulFineGrainScheduler.TLHint): - hint_type: str = "MatmulINT4FineGrainScheduler" + class TLHint(MatmulMMAScheduler.TLHint): + hint_type: str = "MatmulINT4MMAScheduler" def get_roller_configs(self, arch: TileDevice = None, topk: int = 10): layout = f"{'t' if self.trans_A else 'n'}{'t' if self.trans_B else 'n'}" @@ -1011,10 +779,10 @@ def __post_init__(self): @dataclass -class MatmulINT4WeightPropagationScheduler(MatmulWeightPropagationScheduler): +class MatmulINT4MMAWeightPropagationScheduler(MatmulMMAWeightPropagationScheduler): - class TLHint(MatmulWeightPropagationScheduler.TLHint): - hint_type: str = "MatmulINT4WeightPropagationScheduler" + class TLHint(MatmulMMAWeightPropagationScheduler.TLHint): + hint_type: str = "MatmulINT4MMAWeightPropagationScheduler" def get_roller_configs(self, arch: TileDevice = None, topk: int = 10): layout = f"{'t' if self.trans_A else 'n'}{'t' if self.trans_B else 'n'}" @@ -1252,328 +1020,3 @@ def __post_init__(self): assert self.trans_B is True, "Currently only support Matrix B transposed" return - - -def matmul_blocked( - M, - N, - K, - block_M=64, - block_N=64, - block_K=32, - trans_A=False, - trans_B=False, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", - num_stages=2, - threads=128, - enable_rasterization: bool = False, # Enhance L2 Locality -): - A_shape = (K, M) if trans_A else (M, K) - B_shape = (N, K) if trans_B else (K, N) - C_shape = (M, N) - A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) - B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) - - @T.prim_func - def main( - A: T.Buffer(A_shape, in_dtype), - B: T.Buffer(B_shape, in_dtype), - C: T.Buffer(C_shape, out_dtype), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype) - B_shared = T.alloc_shared(B_shared_shape, in_dtype) - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - - T.use_swizzle(10, enable=enable_rasterization) - - T.clear(C_local) - for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - if trans_A: - T.copy(A[k * block_K, by * block_M], A_shared) - else: - T.copy(A[by * block_M, k * block_K], A_shared) - if trans_B: - T.copy(B[bx * block_N, k * block_K], B_shared) - else: - T.copy(B[k * block_K, bx * block_N], B_shared) - T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) - T.copy(C_local, C[by * block_M, bx * block_N]) - - return main - - -def matmul_macro_tensorcore( - M, - N, - K, - in_dtype, - out_dtype, - trans_A, - trans_B, - accum_dtype, - block_row_warps, - block_col_warps, - warp_row_tiles, - warp_col_tiles, - chunk, - num_stages=2, - enable_rasterization: bool = False, -): - assert trans_A is False, "Currently only support Matrix A is not transposed" - assert trans_B is True, "Currently only support Matrix B is transposed" - - block_M = block_row_warps * warp_row_tiles - block_N = block_col_warps * warp_col_tiles - block_K = chunk - - micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(in_dtype) - - A_shape = (M, K) - B_shape = (N, K) - A_shared_shape = (block_M, block_K) - B_shared_shape = (block_N, block_K) - C_shared_shape = ( - block_M // micro_size_x, - block_N // micro_size_y, - micro_size_x, - micro_size_y, - ) - - warp_size = 32 # nvidia gpu warp size is 32 - threads = warp_size * (block_row_warps * block_col_warps) - local_size_a = (micro_size_x * micro_size_k) // warp_size - local_size_b = (micro_size_y * micro_size_k) // warp_size - local_size_c = (micro_size_x * micro_size_y) // warp_size - warp_rows = warp_row_tiles // micro_size_x - warp_cols = warp_col_tiles // micro_size_y - - shared_scope = "shared.dyn" # Literal["shared", "shared.dyn"] while shared for static shared memory - mma_emitter = TensorCoreIntrinEmitter( - a_dtype=in_dtype, - b_dtype=in_dtype, - accum_dtype=accum_dtype, - a_transposed=False, - b_transposed=True, - block_row_warps=block_row_warps, - block_col_warps=block_col_warps, - warp_row_tiles=warp_row_tiles, - warp_col_tiles=warp_col_tiles, - chunk=chunk, - ) - - @T.prim_func - def main( - A: T.Buffer(A_shape, in_dtype), - B: T.Buffer(B_shape, in_dtype), - C: T.Buffer((M, N), out_dtype), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) - B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) - C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) - A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) - B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) - C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) - thread_bindings = T.thread_binding(0, threads, "threadIdx.x") - - T.annotate_layout({ - A_shared: make_swizzle_layout(A_shared), - B_shared: make_swizzle_layout(B_shared), - }) - - T.use_swizzle(panel_size=10, enable=enable_rasterization) - - T.clear(C_local) - - for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - - for i, k in T.Parallel(block_M, block_K): - A_shared[i, k] = A[by * block_M + i, ko * block_K + k] - - for j, k in T.Parallel(block_N, block_K): - B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] - - for ki in T.serial(0, (block_K // micro_size_k)): - - # Load A into fragment - mma_emitter.ldmatrix_a( - A_local, - A_shared, - ki, - thread_bindings=thread_bindings, - ) - - # Load B into fragment - mma_emitter.ldmatrix_b( - B_local, - B_shared, - ki, - thread_bindings=thread_bindings, - ) - - mma_emitter.mma(A_local, B_local, C_local) - - mma_emitter.stmatrix( - C_local, - C_shared, - thread_bindings=thread_bindings, - ) - - for i, j in T.Parallel(block_M, block_N): - C[by * block_M + i, bx * block_N + j] = C_shared[ - i // micro_size_x, - j // micro_size_y, - i % micro_size_x, - j % micro_size_y, - ] - - return main - - -def matmul_macro_tensorcore_weight_propagation_level_ldmatrix( - M, - N, - K, - in_dtype, - out_dtype, - trans_A, - trans_B, - accum_dtype, - block_row_warps, - block_col_warps, - warp_row_tiles, - warp_col_tiles, - chunk, - num_stages=2, - enable_rasterization: bool = False, -): - assert trans_A is False, "Currently only support Matrix A is not transposed" - assert trans_B is True, "Currently only support Matrix B is transposed" - - block_M = block_row_warps * warp_row_tiles - block_N = block_col_warps * warp_col_tiles - block_K = chunk - - micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(in_dtype) - - A_shape = (M, K) - B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y, micro_size_k) - A_shared_shape = (block_M, block_K) - B_shared_shape = ( - block_N // micro_size_y, - block_K // micro_size_k, - micro_size_y, - micro_size_k, - ) - C_shared_shape = ( - block_M // micro_size_x, - block_N // micro_size_y, - micro_size_x, - micro_size_y, - ) - - warp_size = 32 # nvidia gpu warp size is 32 - threads = warp_size * (block_row_warps * block_col_warps) - local_size_a = (micro_size_x * micro_size_k) // warp_size - local_size_b = (micro_size_y * micro_size_k) // warp_size - local_size_c = (micro_size_x * micro_size_y) // warp_size - warp_rows = warp_row_tiles // micro_size_x - warp_cols = warp_col_tiles // micro_size_y - - shared_scope = "shared.dyn" # Literal["shared", "shared.dyn"] while shared for static shared memory - mma_emitter = TensorCoreIntrinEmitterWithLadderTransform( - a_dtype=in_dtype, - b_dtype=in_dtype, - accum_dtype=accum_dtype, - a_transposed=trans_A, - b_transposed=trans_B, - block_row_warps=block_row_warps, - block_col_warps=block_col_warps, - warp_row_tiles=warp_row_tiles, - warp_col_tiles=warp_col_tiles, - chunk=chunk, - transform_kind_b=TransformKind.LDMatrixTransform, - ) - - @T.prim_func - def main( - A: T.Buffer(A_shape, in_dtype), - B: T.Buffer(B_shape, in_dtype), - C: T.Buffer((M, N), out_dtype), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) - B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) - C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) - A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) - B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) - C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) - thread_bindings = T.thread_binding(0, threads, "threadIdx.x") - - T.annotate_layout({ - A_shared: make_swizzle_layout(A_shared), - B_shared: make_swizzle_layout(B_shared), - }) - - T.use_swizzle(panel_size=10, enable=enable_rasterization) - - T.clear(C_local) - - for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - - for i, k in T.Parallel(block_M, block_K): - A_shared[i, k] = A[by * block_M + i, ko * block_K + k] - - for j, k, jj, kk in T.Parallel( - block_N // micro_size_y, - block_K // micro_size_k, - micro_size_y, - micro_size_k, - ): - B_shared[j, k, jj, kk] = B[ - bx * (block_N // micro_size_y) + j, - ko * (block_K // micro_size_k) + k, - jj, - kk, - ] - - for ki in T.serial(0, (block_K // micro_size_k)): - - # Load A into fragment - mma_emitter.ldmatrix_a( - A_local, - A_shared, - ki, - thread_bindings=thread_bindings, - ) - - # Load B into fragment - mma_emitter.ldmatrix_b( - B_local, - B_shared, - ki, - thread_bindings=thread_bindings, - ) - - mma_emitter.mma(A_local, B_local, C_local) - - mma_emitter.stmatrix( - C_local, - C_shared, - thread_bindings=thread_bindings, - ) - - for i, j in T.Parallel(block_M, block_N): - C[by * block_M + i, bx * block_N + j] = C_shared[ - i // micro_size_x, - j // micro_size_y, - i % micro_size_x, - j % micro_size_y, - ] - - return main diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul_tile.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_tile.py new file mode 100644 index 000000000..cc6d166cd --- /dev/null +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_tile.py @@ -0,0 +1,589 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# tile represents tile library + +from bitblas import tvm as tvm +from tvm import DataType +import tvm.tl.language as T +from typing import Optional, List +from bitblas.tl.utils import ( + get_mma_micro_size, + make_mma_swizzle_layout as make_swizzle_layout, +) + +from bitblas.tl.mma_macro_generator import ( + TensorCoreIntrinEmitter, + TensorCoreIntrinEmitterWithLadderTransform +) +from bitblas.base.operator_common import TransformKind +from bitblas.base.arch import TileDevice +from bitblas.base.roller.hint import Hint +from bitblas.base.roller.rasterization import NoRasterization +from bitblas.base.utils import get_roller_hints_from_func +from dataclasses import dataclass +from bitblas.ops.general_matmul.tirscript import (matmul_select_implementation) +from bitblas.tl.base_hint import BaseTLHint +from .base import MatmulBaseParams +# GPU warp configuration for NVIDIA GPUs +warp_size = 32 + + +@dataclass +class MatmulBaseScheduler(MatmulBaseParams): + + def get_roller_configs(self, arch: TileDevice = None, topk: int = 10): + layout = f"{'t' if self.trans_A else 'n'}{'t' if self.trans_B else 'n'}" + + # Simple TIR Compute Expression + ir_module = matmul_select_implementation( + M=self.M, + N=self.N, + K=self.K, + in_dtype=self.in_dtype, + out_dtype=self.out_dtype, + accum_dtype=self.accum_dtype, + layout=layout, + ) + + roller_hints = get_roller_hints_from_func( + ir_module, + arch, + topk, + tensorcore_only=True, + allow_gemv=True, + ) + + if roller_hints is None: + raise ValueError("No Roller Hints Found for TensorCore Scheduling") + + return self.serialize_hints_to_configs(roller_hints) + + def get_hardware_aware_configs(self, arch: TileDevice = None, topk=10): + if arch is None: + arch = self.arch + return self.get_roller_configs(arch, topk) + + # check if required shared memory cache + def check_require_cache(self) -> bool: + with_bias = self.with_bias + + conditions: List[bool] = [] + conditions.append(False) + # Bias Add should be performed in shared memory + conditions.append(with_bias) + return any(conditions) # Always set to False Currently + + +@dataclass +class MatmulTileLibraryScheduler(MatmulBaseScheduler): + + # Default Tile Related Params + block_M: int = 64 + block_N: int = 64 + block_K: int = 32 + num_stages: int = 2 + threads: int = 128 + enable_rasterization: bool = False # Enhance L2 Locality + + class TLHint(BaseTLHint): + + hint_type = "MatmulTileLibraryScheduler" + + def __init__(self): + super().__init__() + + @classmethod + def from_roller_hint(cls, hint: Hint): + tl_hint = cls() + for key, value in hint.__dict__.items(): + setattr(tl_hint, key, value) + + block = hint.block + warp = hint.warp + rstep = hint.rstep + num_stages = hint.pipeline_stage + rasterization_plan = hint.rasterization_plan + enable_rasterization = not isinstance(rasterization_plan, NoRasterization) + + block_row_warps = block[0] // warp[0] + block_col_warps = block[1] // warp[1] + warp_size = 32 # NVIDIA GPU warp size is 32 + if num_stages == 1: + num_stages = 0 # disable pipelining + + tl_hint.block_M = block[0] + tl_hint.block_N = block[1] + tl_hint.block_K = rstep[0] + tl_hint.num_stages = num_stages + tl_hint.threads = warp_size * block_row_warps * block_col_warps + tl_hint.enable_rasterization = enable_rasterization + + return tl_hint + + def get_config_params(self): + return { + "block_M": self.block_M, + "block_N": self.block_N, + "block_K": self.block_K, + "num_stages": self.num_stages, + "threads": self.threads, + "enable_rasterization": self.enable_rasterization, + } + + def __repr__(self): + return ("{" + f"block_M={self.block_M}," + f"block_N={self.block_N}," + f"block_K={self.block_K}," + f"num_stages={self.num_stages}," + f"threads={self.threads}," + f"enable_rasterization={self.enable_rasterization}" + "}") + + def get_configs_sm80(self): + num_stages = 2 + configs = [ + { + 'block_M': 128, + 'block_N': 256, + 'block_K': 32, + 'threads': 128 + }, + { + 'block_M': 256, + 'block_N': 128, + 'block_K': 32, + 'threads': 128 + }, + { + 'block_M': 128, + 'block_N': 128, + 'block_K': 32, + 'threads': 128 + }, + ] + configs = [{**c, 'num_stages': num_stages} for c in configs] + return configs + + def get_hint_type(self): + return self.TLHint.hint_type + + def serialize_hints_to_configs(self, hints: List[Hint]): + configs = [] + for hint in hints: + config = self.TLHint.from_roller_hint(hint) + configs.append(config) + return configs + + def with_default_config(self): + block_M = getattr(self, "block_M", 64) + block_N = getattr(self, "block_N", 64) + block_K = getattr(self, "block_K", 32) + num_stages = getattr(self, "num_stages", 2) + threads = getattr(self, "threads", 128) + enable_rasterization = getattr(self, "enable_rasterization", False) + + return self.apply_config( + block_M=block_M, + block_N=block_N, + block_K=block_K, + num_stages=num_stages, + threads=threads, + enable_rasterization=enable_rasterization, + ) + + def apply_config( + self, + block_M: Optional[int] = None, + block_N: Optional[int] = None, + block_K: Optional[int] = None, + num_stages: Optional[int] = None, + threads: Optional[int] = None, + # Enhance L2 Locality + enable_rasterization: bool = False, + ): + assert block_M is not None, "block_M is required" + assert block_N is not None, "block_N is required" + assert block_K is not None, "block_K is required" + assert num_stages is not None, "num_stages is required" + assert threads is not None, "threads is required" + + M = self.maybe_dynamic(self.M, "m") + N, K = self.N, self.K + assert isinstance(N, int) and isinstance(K, int), "Do not support dynamic N and K Currently" + + trans_A, trans_B = self.trans_A, self.trans_B + in_dtype, out_dtype, accum_dtype = self.in_dtype, self.out_dtype, self.accum_dtype + with_bias = self.with_bias + + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + C_shape = (M, N) + Bias_shape = (N,) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + @T.prim_func + def main( + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, in_dtype), + Bias: T.Buffer(Bias_shape, out_dtype), + C: T.Buffer(C_shape, out_dtype), + ): + with T.Kernel( + T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + T.use_swizzle(10, enable=enable_rasterization) + + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) + + if with_bias: + for i, j in T.Parallel(block_M, block_N): + C_local[i, j] += Bias[bx * block_N + j] + + T.copy(C_local, C[by * block_M, bx * block_N]) + + return self.post_process(main) + + def __post_init__(self): + # Add Config Validation + return + + +# TODO(lei): remove these legacy functions in the future +def matmul_blocked( + M, + N, + K, + block_M=64, + block_N=64, + block_K=32, + trans_A=False, + trans_B=False, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + num_stages=2, + threads=128, + enable_rasterization: bool = False, # Enhance L2 Locality +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + C_shape = (M, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + @T.prim_func + def main( + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, in_dtype), + C: T.Buffer(C_shape, out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + T.use_swizzle(10, enable=enable_rasterization) + + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def matmul_macro_tensorcore( + M, + N, + K, + in_dtype, + out_dtype, + trans_A, + trans_B, + accum_dtype, + block_row_warps, + block_col_warps, + warp_row_tiles, + warp_col_tiles, + chunk, + num_stages=2, + enable_rasterization: bool = False, +): + assert trans_A is False, "Currently only support Matrix A is not transposed" + assert trans_B is True, "Currently only support Matrix B is transposed" + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + + micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(in_dtype) + + A_shape = (M, K) + B_shape = (N, K) + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, block_K) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + + warp_size = 32 # nvidia gpu warp size is 32 + threads = warp_size * (block_row_warps * block_col_warps) + local_size_a = (micro_size_x * micro_size_k) // warp_size + local_size_b = (micro_size_y * micro_size_k) // warp_size + local_size_c = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + shared_scope = "shared.dyn" # Literal["shared", "shared.dyn"] while shared for static shared memory + mma_emitter = TensorCoreIntrinEmitter( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=False, + b_transposed=True, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + ) + + @T.prim_func + def main( + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, in_dtype), + C: T.Buffer((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) + thread_bindings = T.thread_binding(0, threads, "threadIdx.x") + + T.annotate_layout({ + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + }) + + T.use_swizzle(panel_size=10, enable=enable_rasterization) + + T.clear(C_local) + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + + for i, k in T.Parallel(block_M, block_K): + A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + + for j, k in T.Parallel(block_N, block_K): + B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] + + for ki in T.serial(0, (block_K // micro_size_k)): + + # Load A into fragment + mma_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + thread_bindings=thread_bindings, + ) + + # Load B into fragment + mma_emitter.ldmatrix_b( + B_local, + B_shared, + ki, + thread_bindings=thread_bindings, + ) + + mma_emitter.mma(A_local, B_local, C_local) + + mma_emitter.stmatrix( + C_local, + C_shared, + thread_bindings=thread_bindings, + ) + + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] + + return main + + +def matmul_macro_tensorcore_weight_propagation_level_ldmatrix( + M, + N, + K, + in_dtype, + out_dtype, + trans_A, + trans_B, + accum_dtype, + block_row_warps, + block_col_warps, + warp_row_tiles, + warp_col_tiles, + chunk, + num_stages=2, + enable_rasterization: bool = False, +): + assert trans_A is False, "Currently only support Matrix A is not transposed" + assert trans_B is True, "Currently only support Matrix B is transposed" + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + + micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(in_dtype) + + A_shape = (M, K) + B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y, micro_size_k) + A_shared_shape = (block_M, block_K) + B_shared_shape = ( + block_N // micro_size_y, + block_K // micro_size_k, + micro_size_y, + micro_size_k, + ) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + + warp_size = 32 # nvidia gpu warp size is 32 + threads = warp_size * (block_row_warps * block_col_warps) + local_size_a = (micro_size_x * micro_size_k) // warp_size + local_size_b = (micro_size_y * micro_size_k) // warp_size + local_size_c = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + shared_scope = "shared.dyn" # Literal["shared", "shared.dyn"] while shared for static shared memory + mma_emitter = TensorCoreIntrinEmitterWithLadderTransform( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=trans_A, + b_transposed=trans_B, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + transform_kind_b=TransformKind.LDMatrixTransform, + ) + + @T.prim_func + def main( + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, in_dtype), + C: T.Buffer((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) + thread_bindings = T.thread_binding(0, threads, "threadIdx.x") + + T.annotate_layout({ + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + }) + + T.use_swizzle(panel_size=10, enable=enable_rasterization) + + T.clear(C_local) + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + + for i, k in T.Parallel(block_M, block_K): + A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + + for j, k, jj, kk in T.Parallel( + block_N // micro_size_y, + block_K // micro_size_k, + micro_size_y, + micro_size_k, + ): + B_shared[j, k, jj, kk] = B[ + bx * (block_N // micro_size_y) + j, + ko * (block_K // micro_size_k) + k, + jj, + kk, + ] + + for ki in T.serial(0, (block_K // micro_size_k)): + + # Load A into fragment + mma_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + thread_bindings=thread_bindings, + ) + + # Load B into fragment + mma_emitter.ldmatrix_b( + B_local, + B_shared, + ki, + thread_bindings=thread_bindings, + ) + + mma_emitter.mma(A_local, B_local, C_local) + + mma_emitter.stmatrix( + C_local, + C_shared, + thread_bindings=thread_bindings, + ) + + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] + + return main diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul_wmma.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_wmma.py index 447e7e47b..a7a307747 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/matmul_wmma.py +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_wmma.py @@ -6,7 +6,7 @@ from bitblas.base.roller.rasterization import NoRasterization from dataclasses import dataclass from bitblas.tl.base_hint import BaseTLHint -from .matmul_tensorcore import MatmulBaseScheduler +from .matmul_tile import MatmulBaseScheduler # GPU warp configuration for NVIDIA GPUs warp_size = 32 @@ -14,7 +14,7 @@ # TODO(lei): This is not implemented in the current version of the codebase @dataclass -class MatmulFineGrainScheduler(MatmulBaseScheduler): +class MatmulMMAScheduler(MatmulBaseScheduler): # Fine-grained matrix multiplication scheduler # Allows for more detailed configuration. @@ -31,7 +31,7 @@ class MatmulFineGrainScheduler(MatmulBaseScheduler): class TLHint(BaseTLHint): - hint_type: str = "MatmulFineGrainScheduler" + hint_type: str = "MatmulMMAScheduler" def __init__(self): super().__init__() diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/__init__.py b/bitblas/ops/general_matmul/tilelang/dequantize/__init__.py index 494a988d7..2c36e8bad 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/__init__.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/__init__.py @@ -1,18 +1,18 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from .matmul_dequantize_tensorcore import ( - MatmulDequantizeBlockScheduler, # noqa: F401 +from .matmul_dequantize_tile import ( + MatmulDequantizeTileLibraryScheduler, # noqa: F401 ) -from .matmul_dequantize_tensorcore_finegrained import ( - MatmulDequantizeFineGrainedScheduler, # noqa: F401 - MatmulINT4DequantizeFineGrainedScheduler, # noqa: F401 +from .matmul_dequantize_mma import ( + MatmulDequantizeMMAScheduler, # noqa: F401 + MatmulINT4DequantizeMMAScheduler, # noqa: F401 ) -from .matmul_dequantize_tensorcore_weight_transform import ( - MatmulDequantizeWeightPropagationScheduler, # noqa: F401 - MatmulINT4DequantizeWeightPropagationScheduler, # noqa: F401 +from .matmul_dequantize_mma_weight_transform import ( + MatmulDequantizeMMAWeightPropagationScheduler, # noqa: F401 + MatmulINT4DequantizeMMAWeightPropagationScheduler, # noqa: F401 ) from .matmul_dequantize import MatmulDequantizeScheduler diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize.py index 9716ac075..503453760 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize.py @@ -17,14 +17,14 @@ from .base import MatmulDequantizeBaseParams from .gemv_dequantize_simt import GemvDequantizeSIMTScheduler from .matmul_dequantize_simt import MatmulDequantizeSIMTScheduler -from .matmul_dequantize_tensorcore import MatmulDequantizeBlockScheduler -from .matmul_dequantize_tensorcore_finegrained import ( - MatmulDequantizeFineGrainedScheduler, - MatmulINT4DequantizeFineGrainedScheduler, +from .matmul_dequantize_tile import MatmulDequantizeTileLibraryScheduler +from .matmul_dequantize_mma import ( + MatmulDequantizeMMAScheduler, + MatmulINT4DequantizeMMAScheduler, ) -from .matmul_dequantize_tensorcore_weight_transform import ( - MatmulDequantizeWeightPropagationScheduler, - MatmulINT4DequantizeWeightPropagationScheduler, +from .matmul_dequantize_mma_weight_transform import ( + MatmulDequantizeMMAWeightPropagationScheduler, + MatmulINT4DequantizeMMAWeightPropagationScheduler, ) import logging @@ -38,26 +38,24 @@ class MatmulDequantizeScheduler(MatmulDequantizeBaseParams): # Allows for more detailed configuration. gemv_dequantize_simt_scheduler: Optional[GemvDequantizeSIMTScheduler] = None matmul_dequantize_simt_scheduler: Optional[MatmulDequantizeSIMTScheduler] = None - matmul_dequantize_block_scheduler: Optional[MatmulDequantizeBlockScheduler] = None - matmul_dequantize_fine_grained_scheduler: Optional[MatmulDequantizeFineGrainedScheduler] = None + matmul_dequantize_block_scheduler: Optional[MatmulDequantizeTileLibraryScheduler] = None + matmul_dequantize_fine_grained_scheduler: Optional[MatmulDequantizeMMAScheduler] = None matmul_dequantize_weight_propagation_scheduler: Optional[ - MatmulDequantizeWeightPropagationScheduler] = None - matmul_int4_dequantize_fine_grain_scheduler: Optional[ - MatmulINT4DequantizeFineGrainedScheduler] = None + MatmulDequantizeMMAWeightPropagationScheduler] = None + matmul_int4_dequantize_fine_grain_scheduler: Optional[MatmulINT4DequantizeMMAScheduler] = None matmul_int4_dequantize_weight_propagation_scheduler: Optional[ - MatmulINT4DequantizeWeightPropagationScheduler] = None + MatmulINT4DequantizeMMAWeightPropagationScheduler] = None def __init__(self, **kwargs): self.gemv_dequantize_simt_scheduler = GemvDequantizeSIMTScheduler(**kwargs) self.matmul_dequantize_simt_scheduler = MatmulDequantizeSIMTScheduler(**kwargs) - self.matmul_dequantize_block_scheduler = MatmulDequantizeBlockScheduler(**kwargs) - self.matmul_dequantize_fine_grained_scheduler = MatmulDequantizeFineGrainedScheduler( + self.matmul_dequantize_block_scheduler = MatmulDequantizeTileLibraryScheduler(**kwargs) + self.matmul_dequantize_fine_grained_scheduler = MatmulDequantizeMMAScheduler(**kwargs) + self.matmul_dequantize_weight_propagation_scheduler = MatmulDequantizeMMAWeightPropagationScheduler( **kwargs) - self.matmul_dequantize_weight_propagation_scheduler = MatmulDequantizeWeightPropagationScheduler( + self.matmul_int4_dequantize_fine_grain_scheduler = MatmulINT4DequantizeMMAScheduler( **kwargs) - self.matmul_int4_dequantize_fine_grain_scheduler = MatmulINT4DequantizeFineGrainedScheduler( - **kwargs) - self.matmul_int4_dequantize_weight_propagation_scheduler = MatmulINT4DequantizeWeightPropagationScheduler( + self.matmul_int4_dequantize_weight_propagation_scheduler = MatmulINT4DequantizeMMAWeightPropagationScheduler( **kwargs) super().__init__(**kwargs) diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_finegrained.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_mma.py similarity index 98% rename from bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_finegrained.py rename to bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_mma.py index a9fa86125..f651d8c0b 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_finegrained.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_mma.py @@ -17,7 +17,7 @@ from bitblas.base.roller.rasterization import NoRasterization from bitblas.base.utils import get_roller_hints_from_func from dataclasses import dataclass -from bitblas.ops.general_matmul.tilelang.dequantize.matmul_dequantize_tensorcore import ( +from bitblas.ops.general_matmul.tilelang.dequantize.matmul_dequantize_tile import ( MatmulDequantizeBaseScheduler, # noqa: F401 ) from bitblas.tl.base_hint import BaseTLHint @@ -27,7 +27,7 @@ @dataclass -class MatmulDequantizeFineGrainedScheduler(MatmulDequantizeBaseScheduler): +class MatmulDequantizeMMAScheduler(MatmulDequantizeBaseScheduler): # Tensor Core Warp Configuration block_row_warps: int = 2 @@ -43,7 +43,7 @@ class MatmulDequantizeFineGrainedScheduler(MatmulDequantizeBaseScheduler): class TLHint(BaseTLHint): - hint_type: str = "MatmulDequantizeFineGrainedScheduler" + hint_type: str = "MatmulDequantizeMMAScheduler" def __init__(self): super().__init__() @@ -508,10 +508,10 @@ def general_dequant_matmul( @dataclass -class MatmulINT4DequantizeFineGrainedScheduler(MatmulDequantizeFineGrainedScheduler): +class MatmulINT4DequantizeMMAScheduler(MatmulDequantizeMMAScheduler): - class TLHint(MatmulDequantizeFineGrainedScheduler.TLHint): - hint_type: str = "MatmulINT4DequantizeFineGrainedScheduler" + class TLHint(MatmulDequantizeMMAScheduler.TLHint): + hint_type: str = "MatmulINT4DequantizeMMAScheduler" def get_roller_configs(self, arch: TileDevice = None, topk: int = 10): layout = f"{'t' if self.trans_A else 'n'}{'t' if self.trans_B else 'n'}" diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_weight_transform.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_mma_weight_transform.py similarity index 98% rename from bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_weight_transform.py rename to bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_mma_weight_transform.py index 56d052eda..6e24ab2da 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_weight_transform.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_mma_weight_transform.py @@ -11,7 +11,7 @@ ) from bitblas.base.arch import TileDevice from bitblas.base.roller.hint import Hint -from .matmul_dequantize_tensorcore_finegrained import MatmulDequantizeFineGrainedScheduler +from .matmul_dequantize_mma import MatmulDequantizeMMAScheduler from bitblas.tl.mma_macro_generator import ( TensorCoreIntrinEmitterWithLadderTransform, INT4TensorCoreIntrinEmitterWithLadderTransform, @@ -33,13 +33,13 @@ @dataclass -class MatmulDequantizeWeightPropagationScheduler(MatmulDequantizeFineGrainedScheduler): +class MatmulDequantizeMMAWeightPropagationScheduler(MatmulDequantizeMMAScheduler): # force set default weight transform kind to LDMatrixTransform weight_transform_kind: TransformKind = TransformKind.LDMatrixTransform - class TLHint(MatmulDequantizeFineGrainedScheduler.TLHint): - hint_type: str = "MatmulDequantizeWeightPropagationScheduler" + class TLHint(MatmulDequantizeMMAScheduler.TLHint): + hint_type: str = "MatmulDequantizeMMAWeightPropagationScheduler" def apply_config( self, @@ -680,10 +680,10 @@ def is_b_smooth(self): @dataclass -class MatmulINT4DequantizeWeightPropagationScheduler(MatmulDequantizeWeightPropagationScheduler): +class MatmulINT4DequantizeMMAWeightPropagationScheduler(MatmulDequantizeMMAWeightPropagationScheduler): - class TLHint(MatmulDequantizeWeightPropagationScheduler.TLHint): - hint_type: str = "MatmulINT4DequantizeWeightPropagationScheduler" + class TLHint(MatmulDequantizeMMAWeightPropagationScheduler.TLHint): + hint_type: str = "MatmulINT4DequantizeMMAWeightPropagationScheduler" def get_roller_configs(self, arch: TileDevice = None, topk: int = 10): layout = f"{'t' if self.trans_A else 'n'}{'t' if self.trans_B else 'n'}" diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tile.py similarity index 99% rename from bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore.py rename to bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tile.py index 847eb49fd..c605ad9d9 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tile.py @@ -1,5 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +# Tile represents Tile Library + from bitblas import tvm as tvm from tvm import DataType import tvm.tl.language as T @@ -452,7 +454,7 @@ def num_elems_per_byte(self): @dataclass -class MatmulDequantizeBlockScheduler(MatmulDequantizeBaseScheduler): +class MatmulDequantizeTileLibraryScheduler(MatmulDequantizeBaseScheduler): # Default Tile Related Params block_M: int = 128 @@ -463,7 +465,7 @@ class MatmulDequantizeBlockScheduler(MatmulDequantizeBaseScheduler): enable_rasterization: bool = False # Enhance L2 Locality class TLHint(BaseTLHint): - hint_type: str = "MatmulDequantizeBlockScheduler" + hint_type: str = "MatmulDequantizeTileLibraryScheduler" def __init__(self): super().__init__() diff --git a/testing/python/operators/test_general_matmul_tilelang_impl.py b/testing/python/operators/test_general_matmul_tilelang_impl.py index e412e2298..b05d10390 100644 --- a/testing/python/operators/test_general_matmul_tilelang_impl.py +++ b/testing/python/operators/test_general_matmul_tilelang_impl.py @@ -4,7 +4,7 @@ from bitblas import tvm as tvm import bitblas.testing from tvm import tl -from bitblas.ops.general_matmul.tilelang.dense.matmul_tensorcore import ( +from bitblas.ops.general_matmul.tilelang.dense.matmul_tile import ( matmul_blocked, matmul_macro_tensorcore, matmul_macro_tensorcore_weight_propagation_level_ldmatrix, diff --git a/testing/python/operators/test_general_matmul_tilelang_kernel.py b/testing/python/operators/test_general_matmul_tilelang_kernel.py index e89701af8..ec1eb55f3 100644 --- a/testing/python/operators/test_general_matmul_tilelang_kernel.py +++ b/testing/python/operators/test_general_matmul_tilelang_kernel.py @@ -4,23 +4,23 @@ from bitblas import tvm as tvm import bitblas.testing from tvm import tl -from bitblas.ops.general_matmul.tilelang.dense.matmul_tensorcore import ( - MatmulBlockScheduler, - MatmulFineGrainScheduler, - MatmulWeightPropagationScheduler, +from bitblas.ops.general_matmul.tilelang.dense.matmul_tile import ( + MatmulTileLibraryScheduler, + MatmulMMAScheduler, + MatmulMMAWeightPropagationScheduler, ) from bitblas.ops.general_matmul.tilelang.dequantize import ( MatmulDequantizeScheduler, - MatmulDequantizeFineGrainedScheduler, - MatmulDequantizeWeightPropagationScheduler, - MatmulINT4DequantizeFineGrainedScheduler, - MatmulINT4DequantizeWeightPropagationScheduler, + MatmulDequantizeMMAScheduler, + MatmulDequantizeMMAWeightPropagationScheduler, + MatmulINT4DequantizeMMAScheduler, + MatmulINT4DequantizeMMAWeightPropagationScheduler, ) -from bitblas.ops.general_matmul.tilelang.dense.matmul_tensorcore import ( - MatmulINT4FineGrainScheduler, - MatmulINT4WeightPropagationScheduler, +from bitblas.ops.general_matmul.tilelang.dense.matmul_tile import ( + MatmulINT4MMAScheduler, + MatmulINT4MMAWeightPropagationScheduler, ) import torch @@ -41,7 +41,7 @@ def assert_matmul_blocked_with_default_correctness( out_dtype="float16", accum_dtype="float16", ): - matmul = MatmulBlockScheduler( + matmul = MatmulTileLibraryScheduler( M=M, N=N, K=K, @@ -92,7 +92,7 @@ def assert_matmul_blocked_apply_config_correctness( threads=128, enable_rasterization: bool = False, ): - matmul = MatmulBlockScheduler( + matmul = MatmulTileLibraryScheduler( M=M, N=N, K=K, @@ -145,7 +145,7 @@ def assert_matmul_fine_grained_with_default_correctness( accum_dtype="float16", ): - matmul = MatmulFineGrainScheduler( + matmul = MatmulMMAScheduler( M=M, N=N, K=K, @@ -199,7 +199,7 @@ def assert_matmul_fine_grained_apply_config_correctness( enable_rasterization: bool = False, ): - matmul = MatmulFineGrainScheduler( + matmul = MatmulMMAScheduler( M=M, N=N, K=K, @@ -253,7 +253,7 @@ def assert_matmul_weight_propagation_with_default_correctness( accum_dtype="float16", ): - matmul = MatmulWeightPropagationScheduler( + matmul = MatmulMMAWeightPropagationScheduler( M=M, N=N, K=K, @@ -319,7 +319,7 @@ def assert_matmul_weight_propagation_apply_config_correctness( enable_rasterization: bool = False, ): - matmul = MatmulWeightPropagationScheduler( + matmul = MatmulMMAWeightPropagationScheduler( M=M, N=N, K=K, @@ -384,7 +384,7 @@ def assert_matmul_int4_fine_grained_with_default_correctness( accum_dtype="int32", ): - matmul = MatmulINT4FineGrainScheduler( + matmul = MatmulINT4MMAScheduler( M=M, N=N, K=K, @@ -441,7 +441,7 @@ def assert_matmul_int4_fine_grained_apply_config_correctness( enable_rasterization: bool = False, ): - matmul = MatmulINT4FineGrainScheduler( + matmul = MatmulINT4MMAScheduler( M=M, N=N, K=K, @@ -499,7 +499,7 @@ def assert_matmul_int4_weight_propagation_with_default_correctness( accum_dtype="int32", ): - matmul = MatmulINT4WeightPropagationScheduler( + matmul = MatmulINT4MMAWeightPropagationScheduler( M=M, N=N, K=K, @@ -569,7 +569,7 @@ def assert_matmul_int4_weight_propagation_apply_config__correctness( enable_rasterization: bool = False, ): - matmul = MatmulINT4WeightPropagationScheduler( + matmul = MatmulINT4MMAWeightPropagationScheduler( M=M, N=N, K=K, @@ -648,7 +648,7 @@ def assert_matmul_fine_grained_dequant_int4_with_default_correctness( fast_decoding=False, zeros_mode="original", ): - matmul = MatmulINT4DequantizeFineGrainedScheduler( + matmul = MatmulINT4DequantizeMMAScheduler( M=M, N=N, K=K, @@ -739,7 +739,7 @@ def assert_matmul_fine_grained_dequant_int4_apply_config_correctness( num_stages=2, enable_rasterization: bool = False, ): - matmul = MatmulINT4DequantizeFineGrainedScheduler( + matmul = MatmulINT4DequantizeMMAScheduler( M=M, N=N, K=K, @@ -831,7 +831,7 @@ def assert_matmul_weight_transform_dequant_int4_with_default_correctness( fast_decoding=False, zeros_mode="original", ): - matmul = MatmulINT4DequantizeWeightPropagationScheduler( + matmul = MatmulINT4DequantizeMMAWeightPropagationScheduler( M=M, N=N, K=K, @@ -943,7 +943,7 @@ def assert_matmul_weight_transform_dequant_int4_apply_config_correctness( num_stages=2, enable_rasterization: bool = False, ): - matmul = MatmulINT4DequantizeWeightPropagationScheduler( + matmul = MatmulINT4DequantizeMMAWeightPropagationScheduler( M=M, N=N, K=K, @@ -1195,7 +1195,7 @@ def assert_matmul_fine_grained_dequant_with_default_correctness( import numpy as np from bitblas.quantization import general_compress, interleave_weight - matmul = MatmulDequantizeFineGrainedScheduler( + matmul = MatmulDequantizeMMAScheduler( M=M, N=N, K=K, @@ -1325,7 +1325,7 @@ def assert_matmul_weight_transform_dequant_with_default_correctness( import numpy as np from bitblas.quantization import general_compress, interleave_weight - matmul = MatmulDequantizeWeightPropagationScheduler( + matmul = MatmulDequantizeMMAWeightPropagationScheduler( M=M, N=N, K=K, diff --git a/testing/python/operators/test_general_matmul_tilelang_scheduler.py b/testing/python/operators/test_general_matmul_tilelang_scheduler.py index f5b85d409..83428b7a1 100644 --- a/testing/python/operators/test_general_matmul_tilelang_scheduler.py +++ b/testing/python/operators/test_general_matmul_tilelang_scheduler.py @@ -4,8 +4,8 @@ from bitblas import tvm as tvm import bitblas.testing from tvm.ir import structural_equal -from bitblas.ops.general_matmul.tilelang.dense.matmul_tensorcore import ( - MatmulBlockScheduler,) +from bitblas.ops.general_matmul.tilelang.dense.matmul_tile import ( + MatmulTileLibraryScheduler,) from bitblas.ops.general_matmul.tilelang.dequantize import (MatmulDequantizeScheduler) from bitblas.ops.general_matmul.tilelang.dense.gemv_simt import GemvFineGrainSIMTScheduler from bitblas.ops.general_matmul.tilelang.dense import MatmulScheduler @@ -49,7 +49,7 @@ def assert_dense_scheduler_simplify(M, in_dtype="float16", out_dtype="float16", accum_dtype="float16"): - matmul = MatmulBlockScheduler( + matmul = MatmulTileLibraryScheduler( M=M, N=N, K=K, @@ -60,7 +60,7 @@ def assert_dense_scheduler_simplify(M, accum_dtype=accum_dtype, ).deactivate_simplify().with_default_config() - simplified = MatmulBlockScheduler.Simplify(matmul) + simplified = MatmulTileLibraryScheduler.Simplify(matmul) is_equal = structural_equal(matmul, simplified) if is_equal: From 566fc679341025a3de8e3f2f1aec32f076f6bbcf Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 2 Feb 2025 10:22:06 +0000 Subject: [PATCH 02/28] lint fix --- bitblas/ops/general_matmul/tilelang/dense/matmul_tile.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul_tile.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_tile.py index cc6d166cd..8dceefd73 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/matmul_tile.py +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_tile.py @@ -3,7 +3,6 @@ # tile represents tile library from bitblas import tvm as tvm -from tvm import DataType import tvm.tl.language as T from typing import Optional, List from bitblas.tl.utils import ( @@ -11,10 +10,8 @@ make_mma_swizzle_layout as make_swizzle_layout, ) -from bitblas.tl.mma_macro_generator import ( - TensorCoreIntrinEmitter, - TensorCoreIntrinEmitterWithLadderTransform -) +from bitblas.tl.mma_macro_generator import (TensorCoreIntrinEmitter, + TensorCoreIntrinEmitterWithLadderTransform) from bitblas.base.operator_common import TransformKind from bitblas.base.arch import TileDevice from bitblas.base.roller.hint import Hint From 563a70414dad813bd909fa69c934a6077269e0e1 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 2 Feb 2025 12:48:44 +0000 Subject: [PATCH 03/28] Fix: Update Matmul initialization to specify backend and clean up imports --- testing/python/operators/test_general_matmul_fp8.py | 4 ++-- .../python/operators/test_general_matmul_tilelang_kernel.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/testing/python/operators/test_general_matmul_fp8.py b/testing/python/operators/test_general_matmul_fp8.py index b21cdc8ca..0ef2e64af 100644 --- a/testing/python/operators/test_general_matmul_fp8.py +++ b/testing/python/operators/test_general_matmul_fp8.py @@ -27,7 +27,7 @@ def matmul_torch_forward(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layo with_zeros=with_zeros, zeros_mode=zeros_mode, ) - matmul = Matmul(config=matmul_config, enable_tuning=True) + matmul = Matmul(config=matmul_config, enable_tuning=True, backend="tir") input_shape = (M, K) weight_shape = (N, K) if layout == "nt" else (K, N) @@ -93,7 +93,7 @@ def matmul_torch_forward_weight_dequantize(M, N, K, A_dtype, W_dtype, accum_dtyp propagate_a=False, propagate_b=False, ) - matmul = Matmul(config=matmul_config, enable_tuning=False) + matmul = Matmul(config=matmul_config, enable_tuning=False, backend="tir") input_shape = (M, K) weight_shape = (N, K) if layout == "nt" else (K, N) diff --git a/testing/python/operators/test_general_matmul_tilelang_kernel.py b/testing/python/operators/test_general_matmul_tilelang_kernel.py index ec1eb55f3..a40eef639 100644 --- a/testing/python/operators/test_general_matmul_tilelang_kernel.py +++ b/testing/python/operators/test_general_matmul_tilelang_kernel.py @@ -6,8 +6,6 @@ from tvm import tl from bitblas.ops.general_matmul.tilelang.dense.matmul_tile import ( MatmulTileLibraryScheduler, - MatmulMMAScheduler, - MatmulMMAWeightPropagationScheduler, ) from bitblas.ops.general_matmul.tilelang.dequantize import ( @@ -18,7 +16,9 @@ MatmulINT4DequantizeMMAWeightPropagationScheduler, ) -from bitblas.ops.general_matmul.tilelang.dense.matmul_tile import ( +from bitblas.ops.general_matmul.tilelang.dense.matmul_mma import ( + MatmulMMAScheduler, + MatmulMMAWeightPropagationScheduler, MatmulINT4MMAScheduler, MatmulINT4MMAWeightPropagationScheduler, ) From d95fb427f700c70e73d641fe16385d66590dbb2f Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 2 Feb 2025 13:46:29 +0000 Subject: [PATCH 04/28] lint fix --- .../python/operators/test_general_matmul_tilelang_kernel.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/testing/python/operators/test_general_matmul_tilelang_kernel.py b/testing/python/operators/test_general_matmul_tilelang_kernel.py index a40eef639..b789bd4e1 100644 --- a/testing/python/operators/test_general_matmul_tilelang_kernel.py +++ b/testing/python/operators/test_general_matmul_tilelang_kernel.py @@ -5,8 +5,7 @@ import bitblas.testing from tvm import tl from bitblas.ops.general_matmul.tilelang.dense.matmul_tile import ( - MatmulTileLibraryScheduler, -) + MatmulTileLibraryScheduler,) from bitblas.ops.general_matmul.tilelang.dequantize import ( MatmulDequantizeScheduler, From 97ecf573898d04fb094398702eb94ab28b781ab8 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 2 Feb 2025 15:02:13 +0000 Subject: [PATCH 05/28] Update subproject commit for TVM dependency --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index 41edb06ed..b372d9ca2 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 41edb06ed039944978c671afbd2dde5f22667c83 +Subproject commit b372d9ca2159a1afd5439990f68bfa29578a8bac From fe44dd733f67b18768926161720b8531ed4579cd Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 3 Feb 2025 13:21:37 +0000 Subject: [PATCH 06/28] Refactor: Update imports to use tilelang instead of tvm.tl.language --- bitblas/__init__.py | 15 ++++ bitblas/base/base_scheduler.py | 6 +- bitblas/builder/lib_generator/__init__.py | 22 +++--- bitblas/gpu/intrin/hip.py | 2 + .../general_flashatten/tilelang/flashatten.py | 3 +- .../tilelang/dense/gemv_simt.py | 3 +- .../tilelang/dense/matmul_mma.py | 3 +- .../tilelang/dense/matmul_simt.py | 3 +- .../tilelang/dense/matmul_tile.py | 3 +- .../dequantize/gemv_dequantize_simt.py | 3 +- .../dequantize/matmul_dequantize_mma.py | 3 +- .../matmul_dequantize_mma_weight_transform.py | 3 +- .../dequantize/matmul_dequantize_simt.py | 3 +- .../dequantize/matmul_dequantize_tile.py | 3 +- bitblas/ops/operator.py | 6 +- bitblas/tl/mfma_layout.py | 4 +- bitblas/tl/mfma_macro_generator.py | 3 +- bitblas/tl/mma_layout.py | 3 +- bitblas/tl/mma_macro_generator.py | 3 +- bitblas/tl/tuner.py | 6 +- bitblas/tl/wmma_macro_generator.py | 3 +- bitblas/utils/rtmod_analysis.py | 18 ++--- install.sh | 65 +++++++++++++---- install_amd.sh | 30 +++++++- .../builder/test_backend_tir_builder.py | 4 +- .../test_general_matmul_ops_backend.py | 4 ++ .../test_general_matmul_tilelang_impl.py | 14 ++-- .../test_general_matmul_tilelang_kernel.py | 70 +++++++++---------- testing/python/tilelang/test_simplifier.py | 12 ++-- .../python/tilelang/test_tilelang_amd_gemm.py | 8 +-- .../tilelang/test_tilelang_dequantize_gemm.py | 14 ++-- .../test_tilelang_dyanmic_symbolic.py | 20 +++--- .../tilelang/test_tilelang_flash_atten.py | 26 +++---- testing/python/tilelang/test_tilelang_gemm.py | 8 +-- .../tilelang/test_tilelang_gemm_s4_mma.py | 12 ++-- .../tilelang/test_tilelang_gemm_simt.py | 8 +-- .../tilelang/test_tilelang_mfma_macro_gemm.py | 8 +-- .../tilelang/test_tilelang_mma_macro_gemm.py | 24 +++---- 38 files changed, 279 insertions(+), 169 deletions(-) diff --git a/bitblas/__init__.py b/bitblas/__init__.py index 32ff07132..69c6eb2fa 100644 --- a/bitblas/__init__.py +++ b/bitblas/__init__.py @@ -120,6 +120,19 @@ def remove_tvm_path(path): develop_tvm_path + "/python:" + os.environ.get("PYTHONPATH", "")) sys.path.insert(0, develop_tvm_path + "/python") +# TILELANG PATH +if os.environ.get("TILELANG_IMPORT_PATH", None) is None: + install_tilelang_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "3rdparty", "tilelang") + develop_tilelang_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "tilelang") + if os.path.exists(install_tilelang_path): + os.environ["TILELANG_IMPORT_PATH"] = install_tilelang_path + elif (os.path.exists(develop_tilelang_path) and develop_tilelang_path not in sys.path): + os.environ["TILELANG_IMPORT_PATH"] = develop_tilelang_path + else: + logger.warning(TL_TEMPLATE_NOT_FOUND_MESSAGE) + if os.environ.get("TL_CUTLASS_PATH", None) is None: install_cutlass_path = os.path.join( os.path.dirname(os.path.abspath(__file__)), "3rdparty", "cutlass") @@ -133,6 +146,8 @@ def remove_tvm_path(path): logger.warning(CUTLASS_NOT_FOUND_MESSAGE) import tvm as tvm # noqa: E402 +import tilelang as tilelang # noqa: E402 + from .base import ( TileDevice, # noqa: F401 fast_tune, # noqa: F401 diff --git a/bitblas/base/base_scheduler.py b/bitblas/base/base_scheduler.py index d901a4192..06b605640 100644 --- a/bitblas/base/base_scheduler.py +++ b/bitblas/base/base_scheduler.py @@ -1,9 +1,13 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from bitblas import tvm as tvm +from bitblas import tilelang as tilelang from tvm import te from tvm import IRModule from tvm.tir import PrimFunc from typing import Optional, Union, Callable, List, Dict from dataclasses import dataclass, field -from tvm.tl.transform import Simplify +from tilelang.transform import Simplify from abc import ABC, abstractmethod from bitblas.base.arch import TileDevice, is_volta_arch, is_ampere_arch, is_cdna_arch, auto_infer_current_arch from bitblas.base.roller.hint import Hint diff --git a/bitblas/builder/lib_generator/__init__.py b/bitblas/builder/lib_generator/__init__.py index 642198060..f9cf1e627 100644 --- a/bitblas/builder/lib_generator/__init__.py +++ b/bitblas/builder/lib_generator/__init__.py @@ -66,25 +66,25 @@ def compile_lib(self, timeout: float = None, with_tl: bool = False): else: raise ValueError(f"Unsupported platform: {platform}") - if with_tl: - install_tvm_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "../..", "3rdparty", "tvm") - develop_tvm_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "../../..", "3rdparty", "tvm") - - tvm_root = next((path for path in [install_tvm_path, develop_tvm_path] - if os.path.exists(path) and path not in sys.path), None) + if with_tl: + install_tilelang_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "../..", "3rdparty", "tilelang") + develop_tilelang_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "../../..", "3rdparty", "tilelang") + + tilelang_root = next((path for path in [install_tilelang_path, develop_tilelang_path] + if os.path.exists(path) and path not in sys.path), None) if "TL_TEMPLATE_PATH " in os.environ: tl_template_path = os.environ["TL_TEMPLATE_PATH"] else: - tl_template_path = osp.abspath(osp.join(tvm_root, "src/tl")) + tl_template_path = osp.abspath(osp.join(tilelang_root, "src")) - tl_template_path = osp.abspath(osp.join(tvm_root, "src/tl")) + tl_template_path = osp.abspath(osp.join(tilelang_root, "src")) if "TL_CUTLASS_PATH" in os.environ: cutlass_path = os.environ["TL_CUTLASS_PATH"] else: - cutlass_path = osp.abspath(osp.join(tvm_root, "3rdparty/cutlass/include")) + cutlass_path = osp.abspath(osp.join(tilelang_root, "3rdparty/cutlass/include")) command += [ "-I" + tl_template_path, diff --git a/bitblas/gpu/intrin/hip.py b/bitblas/gpu/intrin/hip.py index 9883eaed1..4ac668e4d 100644 --- a/bitblas/gpu/intrin/hip.py +++ b/bitblas/gpu/intrin/hip.py @@ -1,5 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +from bitblas import tvm as tvm +from bitblas import tilelang as tilelang from tvm.runtime import convert from tvm.tir.expr import Cast, IntImm from tvm.tir.function import TensorIntrin diff --git a/bitblas/ops/general_flashatten/tilelang/flashatten.py b/bitblas/ops/general_flashatten/tilelang/flashatten.py index d2a5b2857..81e00076f 100644 --- a/bitblas/ops/general_flashatten/tilelang/flashatten.py +++ b/bitblas/ops/general_flashatten/tilelang/flashatten.py @@ -3,7 +3,8 @@ from bitblas import tvm as tvm from bitblas.base.base_scheduler import BaseScheduler -import tvm.tl.language as T +from bitblas import tilelang as tilelang +import tilelang.language as T from dataclasses import dataclass from typing import Optional import logging diff --git a/bitblas/ops/general_matmul/tilelang/dense/gemv_simt.py b/bitblas/ops/general_matmul/tilelang/dense/gemv_simt.py index 5891acb14..ce19f7c80 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/gemv_simt.py +++ b/bitblas/ops/general_matmul/tilelang/dense/gemv_simt.py @@ -1,9 +1,10 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. from bitblas import tvm as tvm +from bitblas import tilelang as tilelang from functools import reduce from typing import Optional, List -import tvm.tl.language as T +import tilelang.language as T from tvm import DataType from tvm.tir import PrimFunc diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul_mma.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_mma.py index c28ebee9e..f36c05663 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/matmul_mma.py +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_mma.py @@ -3,8 +3,9 @@ # tile represents tile library from bitblas import tvm as tvm +from bitblas import tilelang as tilelang from tvm import DataType -import tvm.tl.language as T +import tilelang.language as T from typing import Optional, List from bitblas.tl.utils import ( get_mma_micro_size, diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul_simt.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_simt.py index 5c44daae3..7d42bb1c9 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/matmul_simt.py +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_simt.py @@ -1,8 +1,9 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. from bitblas import tvm as tvm +from bitblas import tilelang as tilelang from typing import Optional, List -import tvm.tl.language as T +import tilelang.language as T from tvm import DataType from tvm.tir import PrimFunc diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul_tile.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_tile.py index 8dceefd73..a906bc308 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/matmul_tile.py +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_tile.py @@ -3,7 +3,8 @@ # tile represents tile library from bitblas import tvm as tvm -import tvm.tl.language as T +from bitblas import tilelang as tilelang +import tilelang.language as T from typing import Optional, List from bitblas.tl.utils import ( get_mma_micro_size, diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/gemv_dequantize_simt.py b/bitblas/ops/general_matmul/tilelang/dequantize/gemv_dequantize_simt.py index 4f0f3b0c1..bf1d59081 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/gemv_dequantize_simt.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/gemv_dequantize_simt.py @@ -1,9 +1,10 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. from bitblas import tvm as tvm +from bitblas import tilelang as tilelang from functools import reduce from typing import Optional, List -import tvm.tl.language as T +import tilelang.language as T from tvm import DataType from tvm.tir import PrimFunc diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_mma.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_mma.py index f651d8c0b..aea3d331e 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_mma.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_mma.py @@ -2,7 +2,8 @@ # Licensed under the MIT License. from bitblas import tvm as tvm from tvm import DataType -import tvm.tl.language as T +from bitblas import tilelang as tilelang +import tilelang.language as T from typing import Optional, List from bitblas.tl.utils import ( get_mma_micro_size, # noqa: F401 diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_mma_weight_transform.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_mma_weight_transform.py index 6e24ab2da..61ce5c631 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_mma_weight_transform.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_mma_weight_transform.py @@ -1,8 +1,9 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. from bitblas import tvm as tvm +from bitblas import tilelang as tilelang from tvm import DataType -import tvm.tl.language as T +import tilelang.language as T from typing import Optional, List from bitblas.tl.utils import ( get_mma_micro_size, # noqa: F401 diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_simt.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_simt.py index 0fda0b2ad..b903f5140 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_simt.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_simt.py @@ -3,7 +3,8 @@ from bitblas import tvm as tvm from tvm import DataType from tvm.tir import PrimFunc -import tvm.tl.language as T +from bitblas import tilelang as tilelang +import tilelang.language as T from typing import Optional, List from bitblas.base.arch import TileDevice from bitblas.base.roller.hint import Hint diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tile.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tile.py index c605ad9d9..06ab8ee7c 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tile.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tile.py @@ -4,7 +4,8 @@ from bitblas import tvm as tvm from tvm import DataType -import tvm.tl.language as T +from bitblas import tilelang as tilelang +import tilelang.language as T from typing import Optional, List from bitblas.base.arch import TileDevice from bitblas.base.roller.hint import Hint diff --git a/bitblas/ops/operator.py b/bitblas/ops/operator.py index ea99490c5..304723eb8 100644 --- a/bitblas/ops/operator.py +++ b/bitblas/ops/operator.py @@ -1,8 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. from abc import ABC, abstractmethod -from bitblas import tvm -from tvm import tl +from bitblas import tvm as tvm +from bitblas import tilelang as tilelang from tvm import IRModule from tvm.runtime.module import Module from tvm.target import Target @@ -192,7 +192,7 @@ def tvm_callback_hip_postproc(code, _): if self.is_tir_backend(): rt_mod = tvm.build(self.scheduled_ir_module, target=target) elif self.is_tilelang_backend(): - rt_mod = tl.lower( + rt_mod = tilelang.lower( self.scheduled_ir_module, target=target, runtime_only=True) else: raise ValueError(f"Unsupported backend: {self.backend}") diff --git a/bitblas/tl/mfma_layout.py b/bitblas/tl/mfma_layout.py index 79e75e438..af0d3a47b 100644 --- a/bitblas/tl/mfma_layout.py +++ b/bitblas/tl/mfma_layout.py @@ -1,8 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. - +from bitblas import tilelang as tilelang from tvm import DataType -import tvm.tl.language as T +import tilelang.language as T from tvm.runtime import convert diff --git a/bitblas/tl/mfma_macro_generator.py b/bitblas/tl/mfma_macro_generator.py index 0148cd8bf..3a0cc2582 100644 --- a/bitblas/tl/mfma_macro_generator.py +++ b/bitblas/tl/mfma_macro_generator.py @@ -1,7 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import tvm.tl.language as T +from bitblas import tilelang as tilelang +import tilelang.language as T from typing import Tuple from tvm import DataType from tvm.tir import PrimExpr diff --git a/bitblas/tl/mma_layout.py b/bitblas/tl/mma_layout.py index aad3ff955..c091b756a 100644 --- a/bitblas/tl/mma_layout.py +++ b/bitblas/tl/mma_layout.py @@ -3,7 +3,8 @@ from typing import Union from tvm import arith, DataType -import tvm.tl.language as T +from bitblas import tilelang as tilelang +import tilelang.language as T def ldmatrix_32x8_to_shared_16x16_layout(thread_id, local_id): diff --git a/bitblas/tl/mma_macro_generator.py b/bitblas/tl/mma_macro_generator.py index f28233911..92169cbc8 100644 --- a/bitblas/tl/mma_macro_generator.py +++ b/bitblas/tl/mma_macro_generator.py @@ -1,7 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import tvm.tl.language as T +from bitblas import tilelang as tilelang +import tilelang.language as T from typing import Union, Tuple, Optional from bitblas.base.operator_common import TransformKind from tvm import DataType diff --git a/bitblas/tl/tuner.py b/bitblas/tl/tuner.py index a275e91f6..0c6eded9c 100644 --- a/bitblas/tl/tuner.py +++ b/bitblas/tl/tuner.py @@ -1,7 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from bitblas import tvm +from bitblas import tvm as tvm +from bitblas import tilelang as tilelang import os import logging import tempfile @@ -10,7 +11,6 @@ from tvm import IRModule from tvm.runtime import Module from tvm.tir import Schedule -import tvm.tl as tl from bitblas.tl.base_hint import BaseTLHint from bitblas.base.arch import TileDevice from bitblas.base.utils import get_dummy_input_arrays @@ -122,7 +122,7 @@ def tvm_callback_cuda_postproc(code, _): "tir.disable_cse_tir": True, **(config.pass_context if config.pass_context else {}) }): - rt_mod = tl.lower(tl_prim_func, arch.target, runtime_only=True) + rt_mod = tilelang.lower(tl_prim_func, arch.target, runtime_only=True) from tvm.contrib.tar import tar # Import the tar module diff --git a/bitblas/tl/wmma_macro_generator.py b/bitblas/tl/wmma_macro_generator.py index 0b81c1b04..48a42bd1a 100644 --- a/bitblas/tl/wmma_macro_generator.py +++ b/bitblas/tl/wmma_macro_generator.py @@ -1,7 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import tvm.tl.language as T +from bitblas import tilelang as tilelang +import tilelang.language as T from typing import Tuple, Optional from tvm import DataType from tvm.tir import PrimExpr diff --git a/bitblas/utils/rtmod_analysis.py b/bitblas/utils/rtmod_analysis.py index e3fe4c1cb..87b3c24c3 100644 --- a/bitblas/utils/rtmod_analysis.py +++ b/bitblas/utils/rtmod_analysis.py @@ -7,8 +7,8 @@ from tvm.target import Target from typing import Tuple, List from tvm import tir -from tvm import tl -from tvm.tl.engine import is_device_call +from bitblas import tilelang as tilelang +from tilelang.engine import is_device_call def get_annotated_device_mod_from_tl(mod: IRModule, target: Target) -> "IRModule": @@ -16,18 +16,18 @@ def get_annotated_device_mod_from_tl(mod: IRModule, target: Target) -> "IRModule target = tvm.target.Target(target, target_host) mod = tir.transform.BindTarget(target)(mod) - mod = tl.transform.FrontendLegalize()(mod) + mod = tilelang.transform.FrontendLegalize()(mod) mod = tir.transform.Simplify()(mod) - mod = tl.transform.LayoutInference()(mod) - mod = tl.transform.LowerTileOp()(mod) + mod = tilelang.transform.LayoutInference()(mod) + mod = tilelang.transform.LowerTileOp()(mod) mod = tir.transform.Simplify()(mod) if target.arch == "sm_90": - mod = tl.transform.WarpSpecializedPipeline()(mod) + mod = tilelang.transform.WarpSpecializedPipeline()(mod) else: mod = tir.transform.PlanAndUpdateBufferAllocationLocation()(mod) - mod = tl.transform.PipelinePlanning()(mod) - mod = tl.transform.InjectSoftwarePipeline()(mod) + mod = tilelang.transform.PipelinePlanning()(mod) + mod = tilelang.transform.InjectSoftwarePipeline()(mod) mod = tir.transform.LowerOpaqueBlock()(mod) mod = tir.transform.FlattenBuffer()(mod) @@ -57,7 +57,7 @@ def get_annotated_device_mod_from_tl(mod: IRModule, target: Target) -> "IRModule # the Legalization. mod = tir.transform.LowerThreadAllreduce()(mod) mod = tir.transform.ThreadSync("shared.dyn")(mod) - mod = tl.transform.LowerHopperIntrin()(mod) + mod = tilelang.transform.LowerHopperIntrin()(mod) mod = tir.transform.InjectPTXAsyncCopy()(mod) mod = tir.transform.AnnotateDeviceRegions()(mod) diff --git a/install.sh b/install.sh index 49d1fa815..f3c7cbb76 100755 --- a/install.sh +++ b/install.sh @@ -119,22 +119,63 @@ else echo "TVM build completed successfully." fi -cd ../../.. +TVM_PREBUILD_PATH=$(realpath .) -# Step 11: Set environment variables -echo "Configuring environment variables for TVM..." -echo "export TVM_HOME=$(pwd)/3rdparty/tvm" >> ~/.bashrc -echo "export PYTHONPATH=\$TVM_HOME/python:$(pwd):\$PYTHONPATH" >> ~/.bashrc -echo "export CUDA_DEVICE_ORDER=PCI_BUS_ID" >> ~/.bashrc +cd ../.. -# Step 12: Source .bashrc to apply changes -echo "Applying environment changes by sourcing .bashrc..." -source ~/.bashrc +echo "Building TileLang with CMake..." +cd tilelang +mkdir build +cd build + +cmake .. -DTVM_PREBUILD_PATH=$TVM_PREBUILD_PATH +if [ $? -ne 0 ]; then + echo "Error: CMake configuration failed." + exit 1 +fi + +make -j if [ $? -ne 0 ]; then - echo "Error: Failed to source .bashrc." + echo "Error: TileLang build failed." exit 1 else - echo "Environment configured successfully." + echo "TileLang build completed successfully." +fi + +echo "TileLang build completed successfully." + +cd ../../.. + +# Set environment variables +TVM_HOME_ENV="export TVM_HOME=$(pwd)/3rdparty/tvm" +TVM_EXPORT_ENV="export TVM_IMPORT_PYTHON_PATH=/root/BitBLAS/3rdparty/tvm/python" +TILELANG_HOME_ENV="export TILELANG_HOME=$(pwd)/3rdparty/tilelang" +BITBLAS_PYPATH_ENV="export PYTHONPATH=\$TVM_HOME/python:\$TILELANG_HOME:$(pwd):\$PYTHONPATH" +CUDA_DEVICE_ORDER_ENV="export CUDA_DEVICE_ORDER=PCI_BUS_ID" + +# Check and add the first line if not already present +if ! grep -qxF "$TVM_HOME_ENV" ~/.bashrc; then + echo "$TVM_HOME_ENV" >> ~/.bashrc + echo "Added TVM_HOME to ~/.bashrc" +else + echo "TVM_HOME is already set in ~/.bashrc" +fi + +# Check and add the second line if not already present +if ! grep -qxF "$BITBLAS_PYPATH_ENV" ~/.bashrc; then + echo "$BITBLAS_PYPATH_ENV" >> ~/.bashrc + echo "Added PYTHONPATH to ~/.bashrc" +else + echo "PYTHONPATH is already set in ~/.bashrc" +fi + +# Check and add the third line if not already present +if ! grep -qxF "$CUDA_DEVICE_ORDER_ENV" ~/.bashrc; then + echo "$CUDA_DEVICE_ORDER_ENV" >> ~/.bashrc + echo "Added CUDA_DEVICE_ORDER to ~/.bashrc" +else + echo "CUDA_DEVICE_ORDER is already set in ~/.bashrc" fi -echo "Installation script completed successfully." +# Reload ~/.bashrc to apply the changes +source ~/.bashrc diff --git a/install_amd.sh b/install_amd.sh index f64e442a0..716fed1a3 100755 --- a/install_amd.sh +++ b/install_amd.sh @@ -62,9 +62,37 @@ echo "set(USE_LLVM llvm-config-16)" >> config.cmake && echo "set(USE_ROCM /opt/r cmake .. && make -j && cd ../../.. +TVM_PREBUILD_PATH=$(realpath .) + +cd ../.. + +echo "Building TileLang with CMake..." +cd tilelang +mkdir build +cd build + +cmake .. -DTVM_PREBUILD_PATH=$TVM_PREBUILD_PATH +if [ $? -ne 0 ]; then + echo "Error: CMake configuration failed." + exit 1 +fi + +make -j +if [ $? -ne 0 ]; then + echo "Error: TileLang build failed." + exit 1 +else + echo "TileLang build completed successfully." +fi + +echo "TileLang build completed successfully." + +cd ../../.. + # Define the lines to be added TVM_HOME_ENV="export TVM_HOME=$(pwd)/3rdparty/tvm" -BITBLAS_PYPATH_ENV="export PYTHONPATH=\$TVM_HOME/python:$(pwd):\$PYTHONPATH" +TILELANG_HOME_ENV="export TILELANG_HOME=$(pwd)/3rdparty/tilelang" +BITBLAS_PYPATH_ENV="export PYTHONPATH=\$TVM_HOME/python:\$TILELANG_HOME:$(pwd):\$PYTHONPATH" CUDA_DEVICE_ORDER_ENV="export CUDA_DEVICE_ORDER=PCI_BUS_ID" # Check and add the first line if not already present diff --git a/testing/python/builder/test_backend_tir_builder.py b/testing/python/builder/test_backend_tir_builder.py index 84666951a..3dc0f4496 100644 --- a/testing/python/builder/test_backend_tir_builder.py +++ b/testing/python/builder/test_backend_tir_builder.py @@ -56,4 +56,6 @@ def test_matmul_transform_weight(): # fmt: on if __name__ == "__main__": - bitblas.testing.main() + # bitblas.testing.main() + matmul_backend_code_wrap(768, 768, 768, "float16", "uint4", "float16", "float16", False) + diff --git a/testing/python/operators/test_general_matmul_ops_backend.py b/testing/python/operators/test_general_matmul_ops_backend.py index d1a2253f3..8d80f7d87 100644 --- a/testing/python/operators/test_general_matmul_ops_backend.py +++ b/testing/python/operators/test_general_matmul_ops_backend.py @@ -34,6 +34,10 @@ def matmul_codegen_default(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, la zeros_mode=zeros_mode, ) matmul = Matmul(config=matmul_config, enable_tuning=False) + func = matmul.prim_func + import tilelang + rt_mod, params = tilelang.lower(func) + print(rt_mod) assert get_codegen_result(matmul) diff --git a/testing/python/operators/test_general_matmul_tilelang_impl.py b/testing/python/operators/test_general_matmul_tilelang_impl.py index b05d10390..3f258bab8 100644 --- a/testing/python/operators/test_general_matmul_tilelang_impl.py +++ b/testing/python/operators/test_general_matmul_tilelang_impl.py @@ -2,8 +2,8 @@ # Licensed under the MIT License. from bitblas import tvm as tvm +from bitblas import tilelang as tilelang import bitblas.testing -from tvm import tl from bitblas.ops.general_matmul.tilelang.dense.matmul_tile import ( matmul_blocked, matmul_macro_tensorcore, @@ -47,7 +47,7 @@ def assert_matmul_blocked_correctness(M, enable_rasterization=enable_rasterization, ) - mod, params = tl.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -57,7 +57,7 @@ def assert_matmul_blocked_correctness(M, B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype)) - mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) @@ -105,7 +105,7 @@ def assert_matmul_macro_tensorcore_correctness( num_stages=num_stages, enable_rasterization=enable_rasterization, ) - mod, params = tl.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code represents generated cuda source @@ -115,7 +115,7 @@ def assert_matmul_macro_tensorcore_correctness( B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) - 0.5 C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) - mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) @@ -164,7 +164,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_correctness( enable_rasterization=enable_rasterization, ) - mod, params = tl.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -185,7 +185,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_correctness( LB = ladder_permutate(B.cpu()).cuda() - mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, LB, C) diff --git a/testing/python/operators/test_general_matmul_tilelang_kernel.py b/testing/python/operators/test_general_matmul_tilelang_kernel.py index b789bd4e1..dc1f1c424 100644 --- a/testing/python/operators/test_general_matmul_tilelang_kernel.py +++ b/testing/python/operators/test_general_matmul_tilelang_kernel.py @@ -3,7 +3,7 @@ from bitblas import tvm as tvm import bitblas.testing -from tvm import tl +from bitblas import tilelang as tilelang from bitblas.ops.general_matmul.tilelang.dense.matmul_tile import ( MatmulTileLibraryScheduler,) @@ -51,7 +51,7 @@ def assert_matmul_blocked_with_default_correctness( accum_dtype=accum_dtype, ).with_default_config() - mod, params = tl.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -61,7 +61,7 @@ def assert_matmul_blocked_with_default_correctness( B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) - mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) @@ -109,7 +109,7 @@ def assert_matmul_blocked_apply_config_correctness( enable_rasterization=enable_rasterization, ) - mod, params = tl.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -119,7 +119,7 @@ def assert_matmul_blocked_apply_config_correctness( B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) - mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) @@ -155,7 +155,7 @@ def assert_matmul_fine_grained_with_default_correctness( accum_dtype=accum_dtype, ).with_default_config() - mod, params = tl.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -163,7 +163,7 @@ def assert_matmul_fine_grained_with_default_correctness( B = (torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) if trans_B else torch.rand( K, N, device="cuda", dtype=getattr(torch, in_dtype))) - 0.5 C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype)) - mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) latency = mod.do_bench(mod.func, warmup=25) @@ -217,7 +217,7 @@ def assert_matmul_fine_grained_apply_config_correctness( enable_rasterization=enable_rasterization, ) - mod, params = tl.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -227,7 +227,7 @@ def assert_matmul_fine_grained_apply_config_correctness( B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) - 0.5 C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) - mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) @@ -263,7 +263,7 @@ def assert_matmul_weight_propagation_with_default_correctness( accum_dtype=accum_dtype, ).with_default_config() - mod, params = tl.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -284,7 +284,7 @@ def assert_matmul_weight_propagation_with_default_correctness( LB = ladder_permutate(B.cpu()).cuda() - mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, LB, C) @@ -337,7 +337,7 @@ def assert_matmul_weight_propagation_apply_config_correctness( enable_rasterization=enable_rasterization, ) - mod, params = tl.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -358,7 +358,7 @@ def assert_matmul_weight_propagation_apply_config_correctness( LB = ladder_permutate(B.cpu()).cuda() - mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, LB, C) @@ -394,7 +394,7 @@ def assert_matmul_int4_fine_grained_with_default_correctness( accum_dtype=accum_dtype, ).with_default_config() - mod, params = tl.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -405,7 +405,7 @@ def assert_matmul_int4_fine_grained_with_default_correctness( compressed_A = (A[:, ::2] & 0x0F) + ((A[:, 1::2] & 0x0F) << 4) compressed_B = (B[:, ::2] & 0x0F) + ((B[:, 1::2] & 0x0F) << 4) - mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) latency = mod.do_bench(mod.func, warmup=25, profiler="tvm") print(latency) @@ -459,7 +459,7 @@ def assert_matmul_int4_fine_grained_apply_config_correctness( enable_rasterization=enable_rasterization, ) - mod, params = tl.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -470,7 +470,7 @@ def assert_matmul_int4_fine_grained_apply_config_correctness( compressed_A = (A[:, ::2] & 0x0F) + ((A[:, 1::2] & 0x0F) << 4) compressed_B = (B[:, ::2] & 0x0F) + ((B[:, 1::2] & 0x0F) << 4) - mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) latency = mod.do_bench(mod.func, warmup=25, profiler="tvm") print(latency) @@ -509,7 +509,7 @@ def assert_matmul_int4_weight_propagation_with_default_correctness( accum_dtype=accum_dtype, ).with_default_config() print(matmul) - mod, params = tl.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -534,7 +534,7 @@ def assert_matmul_int4_weight_propagation_with_default_correctness( LB = ladder_permutate(compressed_B.cpu()).cuda() - mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(compressed_A, LB, C) @@ -588,7 +588,7 @@ def assert_matmul_int4_weight_propagation_apply_config__correctness( ) print(matmul) - mod, params = tl.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -613,7 +613,7 @@ def assert_matmul_int4_weight_propagation_apply_config__correctness( LB = ladder_permutate(compressed_B.cpu()).cuda() - mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(compressed_A, LB, C) @@ -666,7 +666,7 @@ def assert_matmul_fine_grained_dequant_int4_with_default_correctness( zeros_mode=zeros_mode, ).with_default_config() - mod, params = tl.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -692,7 +692,7 @@ def assert_matmul_fine_grained_dequant_int4_with_default_correctness( compressed_B = (B[:, ::4] & 0x03) + ((B[:, 1::4] & 0x03) << 2) + ((B[:, 2::4] & 0x03) << 4) + ( (B[:, 3::4] & 0x03) << 6) - mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) print(f"{compressed_B=}") if fast_decoding: lop3_compressed_B = lop3_permutate(compressed_B.cpu()).cuda() @@ -765,7 +765,7 @@ def assert_matmul_fine_grained_dequant_int4_apply_config_correctness( enable_rasterization=enable_rasterization, ) - mod, params = tl.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -791,7 +791,7 @@ def assert_matmul_fine_grained_dequant_int4_apply_config_correctness( compressed_B = (B[:, ::4] & 0x03) + ((B[:, 1::4] & 0x03) << 2) + ((B[:, 2::4] & 0x03) << 4) + ( (B[:, 3::4] & 0x03) << 6) - mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) print(f"{compressed_B=}") if fast_decoding: lop3_compressed_B = lop3_permutate(compressed_B.cpu()).cuda() @@ -849,7 +849,7 @@ def assert_matmul_weight_transform_dequant_int4_with_default_correctness( zeros_mode=zeros_mode, ).with_default_config() print(matmul) - mod, params = tl.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -884,7 +884,7 @@ def assert_matmul_weight_transform_dequant_int4_with_default_correctness( ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config) - mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) compressed_B_ladder = ladder_permutate(compressed_B.cpu()).cuda() ladder_shape = compressed_B_ladder.shape int2_shape = (ladder_shape[:-1] + (ladder_shape[-1] // 2,)) @@ -970,7 +970,7 @@ def assert_matmul_weight_transform_dequant_int4_apply_config_correctness( ) print(matmul) - mod, params = tl.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -1005,7 +1005,7 @@ def assert_matmul_weight_transform_dequant_int4_apply_config_correctness( ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config) - mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) compressed_B_ladder = ladder_permutate(compressed_B.cpu()).cuda() ladder_shape = compressed_B_ladder.shape int2_shape = (ladder_shape[:-1] + (ladder_shape[-1] // 2,)) @@ -1078,7 +1078,7 @@ def assert_matmul_blocked_dequant_with_default_correctness( zeros_mode=zeros_mode, ).with_default_config() print(matmul) - mod, params = tl.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -1135,7 +1135,7 @@ def assert_matmul_blocked_dequant_with_default_correctness( permuted_inputs.append(inputs[2]) - mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(*permuted_inputs) @@ -1213,7 +1213,7 @@ def assert_matmul_fine_grained_dequant_with_default_correctness( zeros_mode=zeros_mode, ).with_default_config() - mod, params = tl.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -1268,7 +1268,7 @@ def assert_matmul_fine_grained_dequant_with_default_correctness( permuted_inputs.append(inputs[2]) - mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(*permuted_inputs) @@ -1344,7 +1344,7 @@ def assert_matmul_weight_transform_dequant_with_default_correctness( ).with_default_config() if verbose: print(matmul) - mod, params = tl.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -1415,7 +1415,7 @@ def assert_matmul_weight_transform_dequant_with_default_correctness( permuted_inputs.append(inputs[2]) - mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(*permuted_inputs) diff --git a/testing/python/tilelang/test_simplifier.py b/testing/python/tilelang/test_simplifier.py index 96536670a..18613edc9 100644 --- a/testing/python/tilelang/test_simplifier.py +++ b/testing/python/tilelang/test_simplifier.py @@ -1,6 +1,6 @@ import tvm -from tvm import tl -import tvm.tl.language as T +from bitblas import tilelang as tilelang +import tilelang.language as T def modify( @@ -36,7 +36,7 @@ def main( def test_modify(with_B=False, with_bias=False): tester = modify(with_B=with_B, with_bias=with_bias) mod = tvm.IRModule({tester.attrs["global_symbol"]: tester}) - mod2 = tl.transform.Simplify()(mod) + mod2 = tilelang.transform.Simplify()(mod) assert mod != mod2 @@ -71,11 +71,11 @@ def main( def test_matmul(): func = matmul(1024, 1024, 1024, 128, 128, 32) mod = tvm.IRModule({func.attrs["global_symbol"]: func}) - mod = tl.transform.Simplify()(mod) + mod = tilelang.transform.Simplify()(mod) - rt_mod, params = tl.lower(mod.functions_items()[0][1], runtime_only=False) + rt_mod, params = tilelang.lower(mod.functions_items()[0][1], runtime_only=False) # TODO Profiler only support TensorType, not dynamic variable - profiler = tl.Profiler(rt_mod, params, result_idx=[2]) + profiler = tilelang.Profiler(rt_mod, params, result_idx=[2]) import torch a = torch.randn(1024, 1024, dtype=torch.float16).cuda().half() diff --git a/testing/python/tilelang/test_tilelang_amd_gemm.py b/testing/python/tilelang/test_tilelang_amd_gemm.py index f281f8eb0..20abd415c 100644 --- a/testing/python/tilelang/test_tilelang_amd_gemm.py +++ b/testing/python/tilelang/test_tilelang_amd_gemm.py @@ -3,7 +3,7 @@ from bitblas import tvm as tvm import bitblas.testing -from tvm import tl +from bitblas import tilelang as tilelang def matmul( @@ -27,7 +27,7 @@ def matmul( A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) vec_size = 4 * k_pack - import tvm.tl.language as T + import tilelang.language as T @T.prim_func def main(A: T.Buffer(A_shape, in_dtype), B: T.Buffer(B_shape, in_dtype), C: T.Buffer( @@ -84,8 +84,8 @@ def run_gemm( num_threads, k_pack=k_pack, ) - mod, params = tl.lower(program) - mod = tl.Profiler(mod, params, [2], tl.TensorSupplyType.Integer) + mod, params = tilelang.lower(program) + mod = tilelang.Profiler(mod, params, [2], tilelang.TensorSupplyType.Integer) def ref_program(A, B): import torch diff --git a/testing/python/tilelang/test_tilelang_dequantize_gemm.py b/testing/python/tilelang/test_tilelang_dequantize_gemm.py index e3d47b309..2e4873f89 100644 --- a/testing/python/tilelang/test_tilelang_dequantize_gemm.py +++ b/testing/python/tilelang/test_tilelang_dequantize_gemm.py @@ -6,8 +6,8 @@ import bitblas.testing from bitblas import tvm as tvm from tvm import DataType -from tvm import tl as TL -import tvm.tl.language as T +from bitblas import tilelang as tilelang +import tilelang.language as T from bitblas.quantization import _tir_packed_to_unsigned_convert from bitblas.tl.utils import make_mma_swizzle_layout as make_swizzle_layout from bitblas.tl.mma_macro_generator import ( @@ -45,7 +45,7 @@ def matmul( local_size = MAX_TRANSACTION_SIZE_IN_BITS // DataType(in_dtype).bits local_size_compressed = local_size // num_elems_per_byte - import tvm.tl.language as T + import tilelang.language as T @T.prim_func def main( @@ -123,8 +123,8 @@ def run_gemm( num_threads, ) - mod, params = TL.lower(program) - mod = TL.Profiler(mod, params, [2], TL.TensorSupplyType.Integer) + mod, params = tilelang.lower(program) + mod = tilelang.Profiler(mod, params, [2], tilelang.TensorSupplyType.Integer) out = mod.run_once() assert out is not None @@ -367,7 +367,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct matmul = tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( M, N, K, in_dtype, out_dtype, accum_dtype, transform_b) - mod, params = TL.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -406,7 +406,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct QLB = ladder_permutate(qB.cpu()).cuda() QLB = lop3_permutate(QLB.cpu()).cuda() - mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, QLB, C) diff --git a/testing/python/tilelang/test_tilelang_dyanmic_symbolic.py b/testing/python/tilelang/test_tilelang_dyanmic_symbolic.py index ae63cce9e..ae027cbf4 100644 --- a/testing/python/tilelang/test_tilelang_dyanmic_symbolic.py +++ b/testing/python/tilelang/test_tilelang_dyanmic_symbolic.py @@ -6,8 +6,8 @@ from bitblas import tvm as tvm import bitblas.testing from tvm import DataType -from tvm import tl as TL -import tvm.tl.language as T +from bitblas import tilelang as tilelang +import tilelang.language as T from bitblas.tl.utils import get_swizzle_layout from bitblas.tl.mma_macro_generator import (TensorCoreIntrinEmitter) @@ -178,7 +178,7 @@ def main( def assert_tl_matmul_macro_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): matmul = tl_matmul_macro(N, K, in_dtype, out_dtype, accum_dtype) - mod, params = TL.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -188,7 +188,7 @@ def assert_tl_matmul_macro_correctness(M, N, K, in_dtype, out_dtype, accum_dtype B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) - mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) @@ -217,7 +217,7 @@ def tl_matmul_block( A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) - import tvm.tl.language as T + import tilelang.language as T @T.prim_func def main(A: T.Buffer(A_shape, in_dtype), B: T.Buffer(B_shape, in_dtype), C: T.Buffer( @@ -271,13 +271,13 @@ def assert_tl_matmul_block_correctness( num_stages, num_threads, ) - mod, params = TL.lower(program) + mod, params = tilelang.lower(program) A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype)) - mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) def ref_program(A, B): @@ -318,7 +318,7 @@ def tl_matmul_block_all_dynamic( A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) - import tvm.tl.language as T + import tilelang.language as T @T.prim_func def main(A: T.Buffer(A_shape, in_dtype), B: T.Buffer(B_shape, in_dtype), C: T.Buffer( @@ -370,7 +370,7 @@ def assert_tl_matmul_block_all_dynamic_correctness( num_stages, num_threads, ) - mod, params = TL.lower(program) + mod, params = tilelang.lower(program) if trans_A: A = torch.rand(K, M, device="cuda", dtype=getattr(torch, in_dtype)) else: @@ -381,7 +381,7 @@ def assert_tl_matmul_block_all_dynamic_correctness( B = torch.rand(K, N, device="cuda", dtype=getattr(torch, in_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype)) - mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) def ref_program(A, B): diff --git a/testing/python/tilelang/test_tilelang_flash_atten.py b/testing/python/tilelang/test_tilelang_flash_atten.py index 3b9e33440..2c1c834ee 100644 --- a/testing/python/tilelang/test_tilelang_flash_atten.py +++ b/testing/python/tilelang/test_tilelang_flash_atten.py @@ -3,9 +3,9 @@ import bitblas import bitblas.testing from bitblas import tvm as tvm -from tvm import tl -import tvm.tl.language as T -from tvm.tl.autotuner import * +from bitblas import tilelang as tilelang +import tilelang.language as T +from tilelang.autotuner import * from functools import partial import itertools import torch @@ -66,8 +66,8 @@ def flashattn_tilelang(batch, heads, seq_len, dim, trans_K, dtypeQKV, dtypeAccu, num_stages=num_stages, is_causal=is_causal, ) - mod, params = tl.lower(tl_prim_func) - mod = tl.Profiler(mod, params, [3], tl.TensorSupplyType.Normal) + mod, params = tilelang.lower(tl_prim_func) + mod = tilelang.Profiler(mod, params, [3], tilelang.TensorSupplyType.Normal) from flash_attn.flash_attn_interface import flash_attn_func # TODO Now hack to internal function get the same input, may need to modify 3rdparty:tvm.tl.utils ins = mod._get_inputs() @@ -123,7 +123,7 @@ def main( scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - T.annotate_layout({Q_shared: tl.layout.make_swizzled_layout(Q_shared)}) + T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) @@ -177,8 +177,8 @@ def main( return main - mod, params = tl.lower(kernel()) - mod = tl.Profiler(mod, params, [3], tl.TensorSupplyType.Normal) + mod, params = tilelang.lower(kernel()) + mod = tilelang.Profiler(mod, params, [3], tilelang.TensorSupplyType.Normal) mod.assert_allclose(partial(ref_program, causal=is_causal), rtol=0.01, atol=0.01) @@ -206,7 +206,7 @@ def flashattn_autotune(batch, heads, seq_len, dim, is_causal): ) @jit( out_idx=[3], - supply_type=tl.TensorSupplyType.Normal, + supply_type=tilelang.TensorSupplyType.Normal, ref_prog=partial(ref_program, causal=is_causal), rtol=0.01, atol=0.01, @@ -239,7 +239,7 @@ def main( scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - T.annotate_layout({Q_shared: tl.layout.make_swizzled_layout(Q_shared)}) + T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) @@ -344,7 +344,7 @@ def main( scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - T.annotate_layout({Q_shared: tl.layout.make_swizzled_layout(Q_shared)}) + T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) @@ -398,8 +398,8 @@ def main( return main - mod, params = tl.lower(kernel()) - mod = tl.Profiler(mod, params, [3], tl.TensorSupplyType.Normal) + mod, params = tilelang.lower(kernel()) + mod = tilelang.Profiler(mod, params, [3], tilelang.TensorSupplyType.Normal) mod.assert_allclose(partial(ref_program, causal=is_causal), rtol=0.1, atol=0.1) diff --git a/testing/python/tilelang/test_tilelang_gemm.py b/testing/python/tilelang/test_tilelang_gemm.py index a4722eb99..bd26fcc1f 100644 --- a/testing/python/tilelang/test_tilelang_gemm.py +++ b/testing/python/tilelang/test_tilelang_gemm.py @@ -3,7 +3,7 @@ from bitblas import tvm as tvm import bitblas.testing -from tvm import tl +from bitblas import tilelang as tilelang def matmul( @@ -26,7 +26,7 @@ def matmul( A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) - import tvm.tl.language as T + import tilelang.language as T @T.prim_func def main(A: T.Buffer(A_shape, in_dtype), B: T.Buffer(B_shape, in_dtype), C: T.Buffer( @@ -81,8 +81,8 @@ def run_gemm( num_stages, num_threads, ) - mod, params = tl.lower(program) - mod = tl.Profiler(mod, params, [2], tl.TensorSupplyType.Integer) + mod, params = tilelang.lower(program) + mod = tilelang.Profiler(mod, params, [2], tilelang.TensorSupplyType.Integer) def ref_program(A, B): import torch diff --git a/testing/python/tilelang/test_tilelang_gemm_s4_mma.py b/testing/python/tilelang/test_tilelang_gemm_s4_mma.py index 42e449056..b32fd7833 100644 --- a/testing/python/tilelang/test_tilelang_gemm_s4_mma.py +++ b/testing/python/tilelang/test_tilelang_gemm_s4_mma.py @@ -6,8 +6,8 @@ from bitblas import tvm as tvm import bitblas.testing from tvm import DataType -from tvm import tl as TL -import tvm.tl.language as T +from bitblas import tilelang as tilelang +import tilelang.language as T from bitblas.tl.utils import ( make_mma_swizzle_layout as make_swizzle_layout,) @@ -173,7 +173,7 @@ def main( def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) - mod, params = TL.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -184,7 +184,7 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): compressed_A = (A[:, ::2] & 0x0F) + ((A[:, 1::2] & 0x0F) << 4) compressed_B = (B[:, ::2] & 0x0F) + ((B[:, 1::2] & 0x0F) << 4) - mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(compressed_A, compressed_B, C) print(C) latency = mod.do_bench(mod.func, warmup=25, profiler="tvm") @@ -368,7 +368,7 @@ def main( def assert_tl_matmul_weight_only_transform_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): matmul = tl_matmul_weight_only_transform(M, N, K, in_dtype, out_dtype, accum_dtype) - mod, params = TL.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -391,7 +391,7 @@ def assert_tl_matmul_weight_only_transform_correctness(M, N, K, in_dtype, out_dt ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config) - mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) LB = ladder_permutate(compressed_B.cpu()).cuda() mod(compressed_A, LB, C) diff --git a/testing/python/tilelang/test_tilelang_gemm_simt.py b/testing/python/tilelang/test_tilelang_gemm_simt.py index 67e2f70e2..33e5abae6 100644 --- a/testing/python/tilelang/test_tilelang_gemm_simt.py +++ b/testing/python/tilelang/test_tilelang_gemm_simt.py @@ -6,8 +6,8 @@ from bitblas import tvm as tvm import bitblas.testing from tvm import DataType -from tvm import tl as TL -import tvm.tl.language as T +from bitblas import tilelang as tilelang +import tilelang.language as T from bitblas.tl.utils import get_swizzle_layout from bitblas.base import simplify_prim_func @@ -142,7 +142,7 @@ def main( def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): matmul = tl_matmul_simt(M, N, K, in_dtype, out_dtype, accum_dtype) - mod, params = TL.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() print(src_code) # src_code is the generated cuda source @@ -157,7 +157,7 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) - mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) diff --git a/testing/python/tilelang/test_tilelang_mfma_macro_gemm.py b/testing/python/tilelang/test_tilelang_mfma_macro_gemm.py index 2f44aea85..b1f16c207 100644 --- a/testing/python/tilelang/test_tilelang_mfma_macro_gemm.py +++ b/testing/python/tilelang/test_tilelang_mfma_macro_gemm.py @@ -5,8 +5,8 @@ import torch.backends from bitblas import tvm as tvm import bitblas.testing -from tvm import tl as TL -import tvm.tl.language as T +from bitblas import tilelang as tilelang +import tilelang.language as T from bitblas.tl.utils import make_mfma_swizzle_layout as make_swizzle_layout from bitblas.tl.mfma_macro_generator import ( MatrixCoreIntrinEmitter,) @@ -172,7 +172,7 @@ def main( def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype="float32"): matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) - mod, params = TL.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -186,7 +186,7 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype="floa C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype)) - mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) diff --git a/testing/python/tilelang/test_tilelang_mma_macro_gemm.py b/testing/python/tilelang/test_tilelang_mma_macro_gemm.py index 660aaad89..dbdfd1034 100644 --- a/testing/python/tilelang/test_tilelang_mma_macro_gemm.py +++ b/testing/python/tilelang/test_tilelang_mma_macro_gemm.py @@ -6,8 +6,8 @@ from bitblas import tvm as tvm import bitblas.testing from tvm import DataType -from tvm import tl as TL -import tvm.tl.language as T +from bitblas import tilelang as tilelang +import tilelang.language as T from bitblas.tl.utils import get_swizzle_layout from bitblas.tl.mma_macro_generator import ( TensorCoreIntrinEmitter, @@ -186,7 +186,7 @@ def main( def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) - mod, params = TL.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -200,7 +200,7 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) - mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) @@ -387,7 +387,7 @@ def main( def assert_tl_matmul_with_block_reduce_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): matmul = tl_matmul_with_block_reduce(M, N, K, in_dtype, out_dtype, accum_dtype) - mod, params = TL.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -397,7 +397,7 @@ def assert_tl_matmul_with_block_reduce_correctness(M, N, K, in_dtype, out_dtype, B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) - mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) @@ -564,7 +564,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_correctness(M, N, K, in_d matmul = tl_matmul_with_ladder_weight_only_transform(M, N, K, in_dtype, out_dtype, accum_dtype, transform_b) - mod, params = TL.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -583,7 +583,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_correctness(M, N, K, in_d ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config) - mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) LB = ladder_permutate(B.cpu()).cuda() @@ -824,7 +824,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct matmul = tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( M, N, K, in_dtype, out_dtype, accum_dtype, transform_b) - mod, params = TL.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -863,7 +863,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct QLB = ladder_permutate(qB.cpu()).cuda() QLB = lop3_permutate(QLB.cpu()).cuda() - mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, QLB, C) @@ -1035,7 +1035,7 @@ def assert_tl_matmul_with_ladder_input_weight_transform_correctness(M, N, K, in_ matmul = tl_matmul_with_ladder_input_weight_transform(M, N, K, in_dtype, out_dtype, accum_dtype, transform_a, transform_b) - mod, params = TL.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -1068,7 +1068,7 @@ def assert_tl_matmul_with_ladder_input_weight_transform_correctness(M, N, K, in_ ladder_permutate_b = bitblas.ops.LadderPermutate(ladder_permutate_config_B) - mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) LA = ladder_permutate_a(A.cpu()).cuda() LB = ladder_permutate_b(B.cpu()).cuda() From 7dc068cf29aa66f5fce78ba0e9828f47895e5f65 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 3 Feb 2025 13:23:33 +0000 Subject: [PATCH 07/28] Refactor: Clean up import statements and formatting in bitblas module --- bitblas/__init__.py | 2 +- bitblas/builder/lib_generator/__init__.py | 4 ++-- testing/python/builder/test_backend_tir_builder.py | 4 +--- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/bitblas/__init__.py b/bitblas/__init__.py index 69c6eb2fa..c1ba5559a 100644 --- a/bitblas/__init__.py +++ b/bitblas/__init__.py @@ -146,7 +146,7 @@ def remove_tvm_path(path): logger.warning(CUTLASS_NOT_FOUND_MESSAGE) import tvm as tvm # noqa: E402 -import tilelang as tilelang # noqa: E402 +import tilelang as tilelang # noqa: E402 from .base import ( TileDevice, # noqa: F401 diff --git a/bitblas/builder/lib_generator/__init__.py b/bitblas/builder/lib_generator/__init__.py index f9cf1e627..ba4106259 100644 --- a/bitblas/builder/lib_generator/__init__.py +++ b/bitblas/builder/lib_generator/__init__.py @@ -66,12 +66,12 @@ def compile_lib(self, timeout: float = None, with_tl: bool = False): else: raise ValueError(f"Unsupported platform: {platform}") - if with_tl: + if with_tl: install_tilelang_path = os.path.join( os.path.dirname(os.path.abspath(__file__)), "../..", "3rdparty", "tilelang") develop_tilelang_path = os.path.join( os.path.dirname(os.path.abspath(__file__)), "../../..", "3rdparty", "tilelang") - + tilelang_root = next((path for path in [install_tilelang_path, develop_tilelang_path] if os.path.exists(path) and path not in sys.path), None) diff --git a/testing/python/builder/test_backend_tir_builder.py b/testing/python/builder/test_backend_tir_builder.py index 3dc0f4496..84666951a 100644 --- a/testing/python/builder/test_backend_tir_builder.py +++ b/testing/python/builder/test_backend_tir_builder.py @@ -56,6 +56,4 @@ def test_matmul_transform_weight(): # fmt: on if __name__ == "__main__": - # bitblas.testing.main() - matmul_backend_code_wrap(768, 768, 768, "float16", "uint4", "float16", "float16", False) - + bitblas.testing.main() From 691a0dc175205287e1ee96585d9af6a6c780785d Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 3 Feb 2025 17:51:01 +0000 Subject: [PATCH 08/28] Fix: Add newline injection to .bashrc if the last line is not empty in install scripts --- install.sh | 7 +++++++ install_amd.sh | 7 +++++++ 2 files changed, 14 insertions(+) diff --git a/install.sh b/install.sh index f3c7cbb76..77c258706 100755 --- a/install.sh +++ b/install.sh @@ -153,6 +153,13 @@ TILELANG_HOME_ENV="export TILELANG_HOME=$(pwd)/3rdparty/tilelang" BITBLAS_PYPATH_ENV="export PYTHONPATH=\$TVM_HOME/python:\$TILELANG_HOME:$(pwd):\$PYTHONPATH" CUDA_DEVICE_ORDER_ENV="export CUDA_DEVICE_ORDER=PCI_BUS_ID" +# Inject break line if the last line of the file is not empty +if [ -s ~/.bashrc ]; then + if [ "$(tail -c 1 ~/.bashrc)" != "" ]; then + echo "" >> ~/.bashrc + fi +fi + # Check and add the first line if not already present if ! grep -qxF "$TVM_HOME_ENV" ~/.bashrc; then echo "$TVM_HOME_ENV" >> ~/.bashrc diff --git a/install_amd.sh b/install_amd.sh index 716fed1a3..dec3dedcf 100755 --- a/install_amd.sh +++ b/install_amd.sh @@ -95,6 +95,13 @@ TILELANG_HOME_ENV="export TILELANG_HOME=$(pwd)/3rdparty/tilelang" BITBLAS_PYPATH_ENV="export PYTHONPATH=\$TVM_HOME/python:\$TILELANG_HOME:$(pwd):\$PYTHONPATH" CUDA_DEVICE_ORDER_ENV="export CUDA_DEVICE_ORDER=PCI_BUS_ID" +# Inject break line if the last line of the file is not empty +if [ -s ~/.bashrc ]; then + if [ "$(tail -c 1 ~/.bashrc)" != "" ]; then + echo "" >> ~/.bashrc + fi +fi + # Check and add the first line if not already present if ! grep -qxF "$TVM_HOME_ENV" ~/.bashrc; then echo "$TVM_HOME_ENV" >> ~/.bashrc From fd4c1a67fdb2847bfad92e9c8d0d96464a53cf01 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 3 Feb 2025 17:55:06 +0000 Subject: [PATCH 09/28] Update submodule URLs and branches for TVM and TileLang --- .gitmodules | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/.gitmodules b/.gitmodules index d5b545545..38b621818 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,7 +1,11 @@ [submodule "3rdparty/tvm"] path = 3rdparty/tvm - url = https://github.com/TileLang/tvm.git - branch = upstream + url = https://github.com/tile-ai/tvm.git + branch = tilelang_codebase +[submodule "3rdparty/tilelang"] + path = 3rdparty/tilelang + url = https://github.com/tile-ai/tilelang.git + branch = bitblas [submodule "3rdparty/cutlass"] path = 3rdparty/cutlass url = https://github.com/TileLang/cutlass From e7ad6a95e8e8c7584ee79a43f416a6c362564a56 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 3 Feb 2025 17:56:56 +0000 Subject: [PATCH 10/28] Update tilelang submodule URL and add new subproject commit --- .gitmodules | 2 +- 3rdparty/tilelang | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) create mode 160000 3rdparty/tilelang diff --git a/.gitmodules b/.gitmodules index 38b621818..e480432aa 100644 --- a/.gitmodules +++ b/.gitmodules @@ -4,7 +4,7 @@ branch = tilelang_codebase [submodule "3rdparty/tilelang"] path = 3rdparty/tilelang - url = https://github.com/tile-ai/tilelang.git + url = https://github.com/tile-ai/tilelang branch = bitblas [submodule "3rdparty/cutlass"] path = 3rdparty/cutlass diff --git a/3rdparty/tilelang b/3rdparty/tilelang new file mode 160000 index 000000000..e3b1856dd --- /dev/null +++ b/3rdparty/tilelang @@ -0,0 +1 @@ +Subproject commit e3b1856dd90947cc4992b5cab6537fa87ecb835e From 6da2f5f4a3353bd33d11d1a4214bd8a60da63bf0 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 3 Feb 2025 17:57:08 +0000 Subject: [PATCH 11/28] Update cutlass submodule URL to point to tile-ai repository --- .gitmodules | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitmodules b/.gitmodules index e480432aa..adbfcc33f 100644 --- a/.gitmodules +++ b/.gitmodules @@ -8,5 +8,5 @@ branch = bitblas [submodule "3rdparty/cutlass"] path = 3rdparty/cutlass - url = https://github.com/TileLang/cutlass + url = https://github.com/tile-ai/cutlass branch = tldev From 8f68896cd4c1e46031fc884672344af0bc678011 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 3 Feb 2025 18:08:47 +0000 Subject: [PATCH 12/28] Refactor: Split class definition for MatmulINT4DequantizeMMAWeightPropagationScheduler for improved readability --- .../dequantize/matmul_dequantize_mma_weight_transform.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_mma_weight_transform.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_mma_weight_transform.py index 61ce5c631..67330730d 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_mma_weight_transform.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_mma_weight_transform.py @@ -681,7 +681,8 @@ def is_b_smooth(self): @dataclass -class MatmulINT4DequantizeMMAWeightPropagationScheduler(MatmulDequantizeMMAWeightPropagationScheduler): +class MatmulINT4DequantizeMMAWeightPropagationScheduler( + MatmulDequantizeMMAWeightPropagationScheduler): class TLHint(MatmulDequantizeMMAWeightPropagationScheduler.TLHint): hint_type: str = "MatmulINT4DequantizeMMAWeightPropagationScheduler" From af0a134f15894c2db5ac75db127e8476e47f34ae Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 3 Feb 2025 19:41:21 +0000 Subject: [PATCH 13/28] Enhance environment variable handling for TVM and TileLang paths in initialization --- bitblas/__init__.py | 11 ++++++++--- setup.py | 21 ++++++++++++++++++++- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/bitblas/__init__.py b/bitblas/__init__.py index c1ba5559a..c77844cc0 100644 --- a/bitblas/__init__.py +++ b/bitblas/__init__.py @@ -90,7 +90,7 @@ def new_func(*args, **kwargs): if TVM_IMPORT_PYTHON_PATH is not None: os.environ["PYTHONPATH"] = (TVM_IMPORT_PYTHON_PATH + ":" + os.environ.get("PYTHONPATH", "")) - sys.path.insert(0, TVM_IMPORT_PYTHON_PATH + "/python") + sys.path.insert(0, TVM_IMPORT_PYTHON_PATH) else: # remove the existing tvm path in PYTHONPATH def remove_tvm_path(path): @@ -107,6 +107,7 @@ def remove_tvm_path(path): os.environ["PYTHONPATH"] = ( install_tvm_path + "/python:" + os.environ.get("PYTHONPATH", "")) sys.path.insert(0, install_tvm_path + "/python") + os.environ["TVM_IMPORT_PYTHON_PATH"] = install_tvm_path + "/python" # developed 3rdparty tvm develop_tvm_path = os.path.join( @@ -119,6 +120,7 @@ def remove_tvm_path(path): os.environ["PYTHONPATH"] = ( develop_tvm_path + "/python:" + os.environ.get("PYTHONPATH", "")) sys.path.insert(0, develop_tvm_path + "/python") + os.environ["TVM_IMPORT_PYTHON_PATH"] = develop_tvm_path + "/python" # TILELANG PATH if os.environ.get("TILELANG_IMPORT_PATH", None) is None: @@ -127,12 +129,15 @@ def remove_tvm_path(path): develop_tilelang_path = os.path.join( os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "tilelang") if os.path.exists(install_tilelang_path): - os.environ["TILELANG_IMPORT_PATH"] = install_tilelang_path + os.environ["PYTHONPATH"] = install_tilelang_path + ":" + os.environ.get("PYTHONPATH", "") + sys.path.insert(0, install_tilelang_path) elif (os.path.exists(develop_tilelang_path) and develop_tilelang_path not in sys.path): - os.environ["TILELANG_IMPORT_PATH"] = develop_tilelang_path + os.environ["PYTHONPATH"] = develop_tilelang_path + ":" + os.environ.get("PYTHONPATH", "") + sys.path.insert(0, develop_tilelang_path) else: logger.warning(TL_TEMPLATE_NOT_FOUND_MESSAGE) + if os.environ.get("TL_CUTLASS_PATH", None) is None: install_cutlass_path = os.path.join( os.path.dirname(os.path.abspath(__file__)), "3rdparty", "cutlass") diff --git a/setup.py b/setup.py index 2f8bb4ce4..dfaac3076 100644 --- a/setup.py +++ b/setup.py @@ -240,7 +240,6 @@ def run(self): "3rdparty/tvm/mypy.ini", "3rdparty/tvm/pyproject.toml", "3rdparty/tvm/version.py", - "3rdparty/tvm/src/tl/tl_templates", ] for item in TVM_PREBUILD_ITEMS: source_dir = os.path.join(ROOT_DIR, item) @@ -254,6 +253,26 @@ def run(self): os.makedirs(target_dir) shutil.copy2(source_dir, target_dir) + # Copy the built TILELANG to the package directory + TILELANG_PREBUILD_ITEMS = [ + "3rdparty/tilelang/build/libtilelang_module.so", + "3rdparty/tilelang/build/libtilelang.so", + "3rdparty/tilelang/tilelang", + "3rdparty/tilelang/src/tl_templates", + "3rdparty/tilelang/VERSION", + ] + for item in TILELANG_PREBUILD_ITEMS: + source_dir = os.path.join(ROOT_DIR, item) + target_dir = os.path.join(self.build_lib, PACKAGE_NAME, item) + if os.path.isdir(source_dir): + self.mkpath(target_dir) + distutils.dir_util.copy_tree(source_dir, target_dir) + else: + target_dir = os.path.dirname(target_dir) + if not os.path.exists(target_dir): + os.makedirs(target_dir) + shutil.copy2(source_dir, target_dir) + # Copy CUTLASS to the package directory CUTLASS_PREBUILD_ITEMS = [ "3rdparty/cutlass", From b9bb657eed68c58cfad11b8cd16ee7d263577699 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 3 Feb 2025 19:41:39 +0000 Subject: [PATCH 14/28] Remove unnecessary blank line in initialization of TILELANG_IMPORT_PATH check --- bitblas/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/bitblas/__init__.py b/bitblas/__init__.py index c77844cc0..b6f4bdb35 100644 --- a/bitblas/__init__.py +++ b/bitblas/__init__.py @@ -137,7 +137,6 @@ def remove_tvm_path(path): else: logger.warning(TL_TEMPLATE_NOT_FOUND_MESSAGE) - if os.environ.get("TL_CUTLASS_PATH", None) is None: install_cutlass_path = os.path.join( os.path.dirname(os.path.abspath(__file__)), "3rdparty", "cutlass") From 9f7d4c652c5e0cdbfda57967dc0ce72d19ceafa5 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 4 Feb 2025 05:58:26 +0000 Subject: [PATCH 15/28] Add build_tilelang function to setup.py for TILELANG integration --- setup.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/setup.py b/setup.py index dfaac3076..9ec929045 100644 --- a/setup.py +++ b/setup.py @@ -191,6 +191,27 @@ def build_tvm(llvm_config_path): os.chdir("../../..") +def build_tilelang(TVM_PREBUILD_PATH: str = "./3rdparty/tvm/build"): + """Builds TILELANG.""" + abs_tvm_prebuilt_path = os.path.abspath(TVM_PREBUILD_PATH) + print(f"Using TVM prebuilt path: {abs_tvm_prebuilt_path}") + + os.chdir("3rdparty/tilelang") + if not os.path.exists("build"): + os.makedirs("build") + os.chdir("build") + # Run CMake and make + try: + subprocess.check_call(["cmake", "..", f"-DTVM_PREBUILD_PATH={abs_tvm_prebuilt_path}"]) + num_jobs = multiprocessing.cpu_count() + subprocess.check_call(["make", f"-j{num_jobs}"]) + except subprocess.CalledProcessError as error: + raise RuntimeError("Failed to build TILELANG") from error + finally: + # Go back to the original directory + os.chdir("../../..") + + def setup_llvm_for_tvm(): """Downloads and extracts LLVM, then configures TVM to use it.""" # Assume the download_and_extract_llvm function and its dependencies are defined elsewhere in this script @@ -209,6 +230,8 @@ def run(self): _, llvm_path = setup_llvm_for_tvm() # Build TVM build_tvm(llvm_path) + # Build TILELANG + build_tilelang() # Continue with the standard installation process install.run(self) From 1cc6886bb71adabacda4291f05bc58b3f02cf56b Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 4 Feb 2025 08:45:20 +0000 Subject: [PATCH 16/28] Add TILELANG build step in setup.py --- setup.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/setup.py b/setup.py index 9ec929045..5244cc187 100644 --- a/setup.py +++ b/setup.py @@ -247,6 +247,8 @@ def run(self): _, llvm_path = setup_llvm_for_tvm() # Build TVM build_tvm(llvm_path) + # Build TILELANG + build_tilelang() # Copy the built TVM to the package directory TVM_PREBUILD_ITEMS = [ From 8a53cd4aea74931e48d940d2fa32750a2f6fb2f6 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 4 Feb 2025 09:17:23 +0000 Subject: [PATCH 17/28] Update TileLang subproject and improve type mappings in TLCUDASourceWrapper --- 3rdparty/tilelang | 2 +- bitblas/builder/wrapper/tl.py | 8 ++++---- bitblas/ops/general_matmul/tilelang/dense/gemv_simt.py | 2 +- testing/python/operators/test_general_matmul_fp8.py | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/3rdparty/tilelang b/3rdparty/tilelang index e3b1856dd..6aef1f896 160000 --- a/3rdparty/tilelang +++ b/3rdparty/tilelang @@ -1 +1 @@ -Subproject commit e3b1856dd90947cc4992b5cab6537fa87ecb835e +Subproject commit 6aef1f8968bb3f8f806b74eb1334eb2a44a9ab3a diff --git a/bitblas/builder/wrapper/tl.py b/bitblas/builder/wrapper/tl.py index 85a75601a..c64ca2b8d 100644 --- a/bitblas/builder/wrapper/tl.py +++ b/bitblas/builder/wrapper/tl.py @@ -18,10 +18,10 @@ class TLCUDASourceWrapper(object): _TYPE_MAP = { "float32": "float", - "float16": "half_t", - "bfloat16": "bfloat16_t", - "e4m3_float8": "float_e4m3_t", - "e5m2_float8": "float_e5m2_t", + "float16": "half", + "bfloat16": "__nv_bfloat16", + "e4m3_float8": "__nv_fp8_e4m3", + "e5m2_float8": "__nv_fp8_e5m2", "float64": "double", "int64": "int64_t", "int32": "int", diff --git a/bitblas/ops/general_matmul/tilelang/dense/gemv_simt.py b/bitblas/ops/general_matmul/tilelang/dense/gemv_simt.py index ce19f7c80..74ab74b76 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/gemv_simt.py +++ b/bitblas/ops/general_matmul/tilelang/dense/gemv_simt.py @@ -158,7 +158,7 @@ def main( ) else: for ki in T.serial(micro_size_k): - accum_res[0] += A_local[ki] * B_local[ki] + accum_res[0] += A_local[ki].astype(accum_dtype) * B_local[ki].astype(accum_dtype) with T.attr( T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), diff --git a/testing/python/operators/test_general_matmul_fp8.py b/testing/python/operators/test_general_matmul_fp8.py index 0ef2e64af..0cdf2d545 100644 --- a/testing/python/operators/test_general_matmul_fp8.py +++ b/testing/python/operators/test_general_matmul_fp8.py @@ -27,7 +27,7 @@ def matmul_torch_forward(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layo with_zeros=with_zeros, zeros_mode=zeros_mode, ) - matmul = Matmul(config=matmul_config, enable_tuning=True, backend="tir") + matmul = Matmul(config=matmul_config, enable_tuning=True, backend="tl") input_shape = (M, K) weight_shape = (N, K) if layout == "nt" else (K, N) @@ -93,7 +93,7 @@ def matmul_torch_forward_weight_dequantize(M, N, K, A_dtype, W_dtype, accum_dtyp propagate_a=False, propagate_b=False, ) - matmul = Matmul(config=matmul_config, enable_tuning=False, backend="tir") + matmul = Matmul(config=matmul_config, enable_tuning=False, backend="tl") input_shape = (M, K) weight_shape = (N, K) if layout == "nt" else (K, N) From 3cac227fc31d9c1b79af746540165006a658005e Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 4 Feb 2025 09:17:38 +0000 Subject: [PATCH 18/28] Refactor line continuation for better readability in gemv_simt.py --- bitblas/ops/general_matmul/tilelang/dense/gemv_simt.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/bitblas/ops/general_matmul/tilelang/dense/gemv_simt.py b/bitblas/ops/general_matmul/tilelang/dense/gemv_simt.py index 74ab74b76..2f9eaa7c8 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/gemv_simt.py +++ b/bitblas/ops/general_matmul/tilelang/dense/gemv_simt.py @@ -158,7 +158,8 @@ def main( ) else: for ki in T.serial(micro_size_k): - accum_res[0] += A_local[ki].astype(accum_dtype) * B_local[ki].astype(accum_dtype) + accum_res[0] += A_local[ki].astype(accum_dtype) * B_local[ki].astype( + accum_dtype) with T.attr( T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), From 494434e33440ed770f61632b543e2d5c8d75bf8d Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 4 Feb 2025 15:21:19 +0000 Subject: [PATCH 19/28] Update TileLang subproject to latest commit --- 3rdparty/tilelang | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tilelang b/3rdparty/tilelang index 6aef1f896..d0dbc46dc 160000 --- a/3rdparty/tilelang +++ b/3rdparty/tilelang @@ -1 +1 @@ -Subproject commit 6aef1f8968bb3f8f806b74eb1334eb2a44a9ab3a +Subproject commit d0dbc46dc788db50bd65297f336717d24ebf97da From 6d8e05d134453661c6a17fc994c6fc1683175aff Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 4 Feb 2025 16:31:27 +0000 Subject: [PATCH 20/28] Update subproject commit for TVM to the latest version --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index b372d9ca2..d310bd5aa 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit b372d9ca2159a1afd5439990f68bfa29578a8bac +Subproject commit d310bd5aadce96145546fb7a87a6d325ea392b2b From 34d77e4b5ad30a8bcf59b1e7d44d3676947ff3a9 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Wed, 5 Feb 2025 16:52:31 +0000 Subject: [PATCH 21/28] Update TileLang subproject and fix CUDA architecture checks --- 3rdparty/tilelang | 2 +- bitblas/base/arch/cuda.py | 2 +- bitblas/gpu/matmul_analysis.py | 2 +- .../general_matmul/tilelang/dense/__init__.py | 8 +-- .../general_matmul/tilelang/dense/matmul.py | 63 ++++++++++-------- .../tilelang/dense/matmul_mma.py | 2 +- .../tilelang/dense/matmul_tile.py | 2 +- .../tilelang/dequantize/matmul_dequantize.py | 64 ++++++++++--------- .../dequantize/matmul_dequantize_mma.py | 8 +-- .../matmul_dequantize_mma_weight_transform.py | 8 +-- 10 files changed, 87 insertions(+), 74 deletions(-) diff --git a/3rdparty/tilelang b/3rdparty/tilelang index d0dbc46dc..fffda9358 160000 --- a/3rdparty/tilelang +++ b/3rdparty/tilelang @@ -1 +1 @@ -Subproject commit d0dbc46dc788db50bd65297f336717d24ebf97da +Subproject commit fffda93581572d3393ec2f96483533ffa6f72c1e diff --git a/bitblas/base/arch/cuda.py b/bitblas/base/arch/cuda.py index 5e8730d67..25c83bff1 100644 --- a/bitblas/base/arch/cuda.py +++ b/bitblas/base/arch/cuda.py @@ -27,7 +27,7 @@ def is_volta_arch(arch: TileDevice) -> bool: def is_ampere_arch(arch: TileDevice) -> bool: conditions = [True] conditions.append(is_cuda_arch(arch)) - conditions.append(arch.sm_version >= 80 and arch.sm_version < 90) + conditions.append(arch.sm_version >= 80 and arch.sm_version < 89) return all(conditions) diff --git a/bitblas/gpu/matmul_analysis.py b/bitblas/gpu/matmul_analysis.py index d43d95ffa..f954f3f88 100644 --- a/bitblas/gpu/matmul_analysis.py +++ b/bitblas/gpu/matmul_analysis.py @@ -649,7 +649,7 @@ def check_last_trait(region: List[Range]): if target.kind.name == "cuda" and check_sm_version(target.arch) >= 70: in_dtype, out_dtype = get_in_out_dtypes(block_stmt) if not is_tensorcore_supported_precision(in_dtype, out_dtype, arch=get_arch(target)): - logger.debug("The input and output dtype is not supported by tensorcore") + logger.debug(f"The input and output dtype ({in_dtype}, {out_dtype})is not supported by tensorcore") return func, None # reindex and transform functions diff --git a/bitblas/ops/general_matmul/tilelang/dense/__init__.py b/bitblas/ops/general_matmul/tilelang/dense/__init__.py index 17843fc0f..888bbf2c9 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/__init__.py +++ b/bitblas/ops/general_matmul/tilelang/dense/__init__.py @@ -102,7 +102,7 @@ def ampere_select_scheduler( trans_A, trans_B = parse_layout(layout) - def can_apply_fine_grain_scheduler(trans_A, trans_B, propagate_a, propagate_b): + def can_apply_mma_scheduler(trans_A, trans_B, propagate_a, propagate_b): conditions = [] conditions.append(trans_A is False) conditions.append(trans_B is True) @@ -116,7 +116,7 @@ def can_apply_block_scheduler(propagate_a, propagate_b): conditions.append(propagate_b == TransformKind.NonTransform) return all(conditions) - def can_apply_weight_propagation_scheduler(trans_A, trans_B, propagate_a, propagate_b): + def can_apply_mma_weight_propagation_scheduler(trans_A, trans_B, propagate_a, propagate_b): conditions = [] conditions.append(trans_A is False) conditions.append(trans_B is True) @@ -127,7 +127,7 @@ def can_apply_weight_propagation_scheduler(trans_A, trans_B, propagate_a, propag def is_int4_dtype(dtype): return dtype == "int4" or dtype == "uint4" - if can_apply_weight_propagation_scheduler(trans_A, trans_B, propagate_a, propagate_b): + if can_apply_mma_weight_propagation_scheduler(trans_A, trans_B, propagate_a, propagate_b): Scheduler = MatmulMMAWeightPropagationScheduler if not is_int4_dtype( in_dtype) else MatmulINT4MMAWeightPropagationScheduler return Scheduler( @@ -141,7 +141,7 @@ def is_int4_dtype(dtype): accum_dtype=accum_dtype, with_bias=with_bias, ) - if can_apply_fine_grain_scheduler(trans_A, trans_B, propagate_a, propagate_b): + if can_apply_mma_scheduler(trans_A, trans_B, propagate_a, propagate_b): Scheduler = MatmulMMAScheduler if not is_int4_dtype(in_dtype) else MatmulINT4MMAScheduler return Scheduler( M=M, diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul.py b/bitblas/ops/general_matmul/tilelang/dense/matmul.py index d18a77c8d..7735b6a77 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/matmul.py +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul.py @@ -9,8 +9,11 @@ TileDevice, is_ampere_arch, is_volta_arch, + is_ada_arch, + is_hopper_arch, is_tensorcore_supported_precision, ) +from tilelang.intrinsics.utils import get_mma_micro_size from dataclasses import dataclass from bitblas.tl.base_hint import BaseTLHint @@ -39,20 +42,20 @@ class MatmulScheduler(MatmulBaseParams): gemv_scheduler: Optional[GemvFineGrainSIMTScheduler] = None matmul_simt_scheduler: Optional[MatmulFineGrainSIMTScheduler] = None matmul_block_scheduler: Optional[MatmulTileLibraryScheduler] = None - matmul_fine_grain_scheduler: Optional[MatmulMMAScheduler] = None - matmul_weight_propagation_scheduler: Optional[MatmulMMAWeightPropagationScheduler] = None - matmul_int4_fine_grain_scheduler: Optional[MatmulINT4MMAScheduler] = None - matmul_int4_weight_propagation_scheduler: Optional[ + matmul_mma_scheduler: Optional[MatmulMMAScheduler] = None + matmul_mma_weight_propagation_scheduler: Optional[MatmulMMAWeightPropagationScheduler] = None + matmul_int4_mma_scheduler: Optional[MatmulINT4MMAScheduler] = None + matmul_int4_mma_weight_propagation_scheduler: Optional[ MatmulINT4MMAWeightPropagationScheduler] = None def __init__(self, **kwargs): self.gemv_scheduler = GemvFineGrainSIMTScheduler(**kwargs) self.matmul_simt_scheduler = MatmulFineGrainSIMTScheduler(**kwargs) self.matmul_block_scheduler = MatmulTileLibraryScheduler(**kwargs) - self.matmul_fine_grain_scheduler = MatmulMMAScheduler(**kwargs) - self.matmul_weight_propagation_scheduler = MatmulMMAWeightPropagationScheduler(**kwargs) - self.matmul_int4_fine_grain_scheduler = MatmulINT4MMAScheduler(**kwargs) - self.matmul_int4_weight_propagation_scheduler = MatmulINT4MMAWeightPropagationScheduler( + self.matmul_mma_scheduler = MatmulMMAScheduler(**kwargs) + self.matmul_mma_weight_propagation_scheduler = MatmulMMAWeightPropagationScheduler(**kwargs) + self.matmul_int4_mma_scheduler = MatmulINT4MMAScheduler(**kwargs) + self.matmul_int4_mma_weight_propagation_scheduler = MatmulINT4MMAWeightPropagationScheduler( **kwargs) super().__init__(**kwargs) @@ -72,14 +75,14 @@ def dispatch_ampere_scheduler(self, arch: TileDevice) -> BaseScheduler: if is_tensorcore_supported_precision(in_dtype, accum_dtype, arch): if weight_transform_kind != TransformKind.NonTransform: # INT4 Can be fused into general dequantize - return self.matmul_int4_weight_propagation_scheduler if in_dtype == "int4" else self.matmul_weight_propagation_scheduler - return self.matmul_int4_fine_grain_scheduler if in_dtype == "int4" else self.matmul_fine_grain_scheduler + return self.matmul_int4_mma_weight_propagation_scheduler if in_dtype == "int4" else self.matmul_mma_weight_propagation_scheduler + return self.matmul_int4_mma_scheduler if in_dtype == "int4" else self.matmul_mma_scheduler else: return self.matmul_simt_scheduler else: + _, _, micro_size_k = get_mma_micro_size(in_dtype) minimal_tensorcore_threshold: List[int, int, - int] = [8, 16, 32 - ] if accum_dtype == "int32" else [8, 16, 16] + int] = [8, 16, micro_size_k] if minimal_tensorcore_threshold[0] > M or minimal_tensorcore_threshold[ 1] > N or minimal_tensorcore_threshold[2] > K: if in_dtype == "int4": @@ -90,10 +93,11 @@ def dispatch_ampere_scheduler(self, arch: TileDevice) -> BaseScheduler: return self.gemv_scheduler elif is_tensorcore_supported_precision(in_dtype, accum_dtype, arch): if self.weight_transform_kind != TransformKind.NonTransform: - return (self.matmul_int4_weight_propagation_scheduler - if in_dtype == "int4" else self.matmul_weight_propagation_scheduler) + return (self.matmul_int4_mma_weight_propagation_scheduler + if in_dtype == "int4" else self.matmul_mma_weight_propagation_scheduler) else: - return self.matmul_int4_fine_grain_scheduler if in_dtype == "int4" else self.matmul_block_scheduler + # by default, use the mma_scheduler + return self.matmul_int4_mma_scheduler if in_dtype == "int4" else self.matmul_mma_scheduler else: return self.matmul_simt_scheduler @@ -131,7 +135,10 @@ def dispatch_volta_scheduler(self, arch: TileDevice) -> BaseScheduler: return self.matmul_simt_scheduler def dispatch_scheduler(self, arch: TileDevice) -> BaseScheduler: - if is_ampere_arch(arch): + if is_hopper_arch(arch): + logger.warning("Hopper architecture is not fully supported yet, fallback to Ada") + return self.dispatch_ampere_scheduler(arch) + elif is_ampere_arch(arch) or is_ada_arch(arch): return self.dispatch_ampere_scheduler(arch) elif is_volta_arch(arch): return self.dispatch_volta_scheduler(arch) @@ -143,10 +150,10 @@ def detect_scheduler_from_hint(self, hint: BaseTLHint) -> BaseScheduler: self.gemv_scheduler, self.matmul_simt_scheduler, self.matmul_block_scheduler, - self.matmul_fine_grain_scheduler, - self.matmul_weight_propagation_scheduler, - self.matmul_int4_fine_grain_scheduler, - self.matmul_int4_weight_propagation_scheduler, + self.matmul_mma_scheduler, + self.matmul_mma_weight_propagation_scheduler, + self.matmul_int4_mma_scheduler, + self.matmul_int4_mma_weight_propagation_scheduler, ]: try: scheduler_hint_type = scheduler.get_hint_type() @@ -213,10 +220,10 @@ def set_dynamic_range(self, dynamic_range: Dict[str, int]) -> "BaseScheduler": self.gemv_scheduler, self.matmul_simt_scheduler, self.matmul_block_scheduler, - self.matmul_fine_grain_scheduler, - self.matmul_weight_propagation_scheduler, - self.matmul_int4_fine_grain_scheduler, - self.matmul_int4_weight_propagation_scheduler, + self.matmul_mma_scheduler, + self.matmul_mma_weight_propagation_scheduler, + self.matmul_int4_mma_scheduler, + self.matmul_int4_mma_weight_propagation_scheduler, ]: scheduler.set_dynamic_range(dynamic_range) return self @@ -227,10 +234,10 @@ def with_arch(self, arch): self.gemv_scheduler, self.matmul_simt_scheduler, self.matmul_block_scheduler, - self.matmul_fine_grain_scheduler, - self.matmul_weight_propagation_scheduler, - self.matmul_int4_fine_grain_scheduler, - self.matmul_int4_weight_propagation_scheduler, + self.matmul_mma_scheduler, + self.matmul_mma_weight_propagation_scheduler, + self.matmul_int4_mma_scheduler, + self.matmul_int4_mma_weight_propagation_scheduler, ]: scheduler.with_arch(arch) return self diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul_mma.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_mma.py index f36c05663..531035647 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/matmul_mma.py +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_mma.py @@ -7,7 +7,7 @@ from tvm import DataType import tilelang.language as T from typing import Optional, List -from bitblas.tl.utils import ( +from tilelang.intrinsics.utils import ( get_mma_micro_size, make_mma_swizzle_layout as make_swizzle_layout, ) diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul_tile.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_tile.py index a906bc308..763184e90 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/matmul_tile.py +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_tile.py @@ -6,7 +6,7 @@ from bitblas import tilelang as tilelang import tilelang.language as T from typing import Optional, List -from bitblas.tl.utils import ( +from tilelang.intrinsics.utils import ( get_mma_micro_size, make_mma_swizzle_layout as make_swizzle_layout, ) diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize.py index 503453760..574015950 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize.py @@ -9,8 +9,11 @@ TileDevice, is_ampere_arch, is_volta_arch, + is_ada_arch, + is_hopper_arch, is_tensorcore_supported_precision, ) +from tilelang.intrinsics.utils import get_mma_micro_size from dataclasses import dataclass from bitblas.tl.base_hint import BaseTLHint @@ -39,23 +42,23 @@ class MatmulDequantizeScheduler(MatmulDequantizeBaseParams): gemv_dequantize_simt_scheduler: Optional[GemvDequantizeSIMTScheduler] = None matmul_dequantize_simt_scheduler: Optional[MatmulDequantizeSIMTScheduler] = None matmul_dequantize_block_scheduler: Optional[MatmulDequantizeTileLibraryScheduler] = None - matmul_dequantize_fine_grained_scheduler: Optional[MatmulDequantizeMMAScheduler] = None - matmul_dequantize_weight_propagation_scheduler: Optional[ + matmul_dequantize_mma_scheduler: Optional[MatmulDequantizeMMAScheduler] = None + matmul_dequantize_mma_weight_propagation_scheduler: Optional[ MatmulDequantizeMMAWeightPropagationScheduler] = None - matmul_int4_dequantize_fine_grain_scheduler: Optional[MatmulINT4DequantizeMMAScheduler] = None - matmul_int4_dequantize_weight_propagation_scheduler: Optional[ + matmul_int4_dequantize_mma_scheduler: Optional[MatmulINT4DequantizeMMAScheduler] = None + matmul_int4_dequantize_mma_weight_propagation_scheduler: Optional[ MatmulINT4DequantizeMMAWeightPropagationScheduler] = None def __init__(self, **kwargs): self.gemv_dequantize_simt_scheduler = GemvDequantizeSIMTScheduler(**kwargs) self.matmul_dequantize_simt_scheduler = MatmulDequantizeSIMTScheduler(**kwargs) self.matmul_dequantize_block_scheduler = MatmulDequantizeTileLibraryScheduler(**kwargs) - self.matmul_dequantize_fine_grained_scheduler = MatmulDequantizeMMAScheduler(**kwargs) - self.matmul_dequantize_weight_propagation_scheduler = MatmulDequantizeMMAWeightPropagationScheduler( + self.matmul_dequantize_mma_scheduler = MatmulDequantizeMMAScheduler(**kwargs) + self.matmul_dequantize_mma_weight_propagation_scheduler = MatmulDequantizeMMAWeightPropagationScheduler( **kwargs) - self.matmul_int4_dequantize_fine_grain_scheduler = MatmulINT4DequantizeMMAScheduler( + self.matmul_int4_dequantize_mma_scheduler = MatmulINT4DequantizeMMAScheduler( **kwargs) - self.matmul_int4_dequantize_weight_propagation_scheduler = MatmulINT4DequantizeMMAWeightPropagationScheduler( + self.matmul_int4_dequantize_mma_weight_propagation_scheduler = MatmulINT4DequantizeMMAWeightPropagationScheduler( **kwargs) super().__init__(**kwargs) @@ -76,10 +79,10 @@ def dispatch_ampere_scheduler(self, arch: TileDevice) -> BaseScheduler: if is_tensorcore_supported_precision(in_dtype, accum_dtype, arch): if weight_transform_kind != TransformKind.NonTransform: # INT4 Can be fused into general dequantize - return (self.matmul_int4_dequantize_weight_propagation_scheduler if in_dtype - == "int4" else self.matmul_dequantize_weight_propagation_scheduler) + return (self.matmul_int4_dequantize_mma_weight_propagation_scheduler if in_dtype + == "int4" else self.matmul_dequantize_mma_weight_propagation_scheduler) else: - return self.matmul_int4_dequantize_fine_grain_scheduler if in_dtype == "int4" else self.matmul_dequantize_fine_grained_scheduler + return self.matmul_int4_dequantize_mma_scheduler if in_dtype == "int4" else self.matmul_dequantize_mma_scheduler else: if in_dtype == "int4": raise ValueError("INT4 is not supported for non-TensorCore architectures") @@ -88,8 +91,8 @@ def dispatch_ampere_scheduler(self, arch: TileDevice) -> BaseScheduler: "Weight propagation is not supported for non-TensorCore architectures") return self.matmul_dequantize_simt_scheduler else: - minimal_tensorcore_threshold: List[int, int, int] = ([8, 16, 32] if accum_dtype - == "int32" else [8, 16, 16]) + _, _, micro_size_k = get_mma_micro_size(in_dtype) + minimal_tensorcore_threshold: List[int, int, int] = [8, 16, micro_size_k] if (minimal_tensorcore_threshold[0] > M or minimal_tensorcore_threshold[1] > N or minimal_tensorcore_threshold[2] > K): if in_dtype == "int4": @@ -101,10 +104,10 @@ def dispatch_ampere_scheduler(self, arch: TileDevice) -> BaseScheduler: elif is_tensorcore_supported_precision(in_dtype, accum_dtype, arch): if self.weight_transform_kind != TransformKind.NonTransform: return ( - self.matmul_int4_dequantize_weight_propagation_scheduler - ) if in_dtype == "int4" else self.matmul_dequantize_weight_propagation_scheduler + self.matmul_int4_dequantize_mma_weight_propagation_scheduler + ) if in_dtype == "int4" else self.matmul_dequantize_mma_weight_propagation_scheduler else: - return self.matmul_int4_dequantize_fine_grain_scheduler if in_dtype == "int4" else self.matmul_dequantize_fine_grained_scheduler + return self.matmul_int4_dequantize_mma_scheduler if in_dtype == "int4" else self.matmul_dequantize_mma_scheduler else: return self.matmul_dequantize_simt_scheduler @@ -142,7 +145,10 @@ def dispatch_volta_scheduler(self, arch: TileDevice) -> BaseScheduler: return self.matmul_dequantize_simt_scheduler def dispatch_scheduler(self, arch: TileDevice) -> BaseScheduler: - if is_ampere_arch(arch): + if is_hopper_arch(arch): + logger.warning("Hopper architecture is not supported for dequantize") + return self.dispatch_ampere_scheduler(arch) + elif is_ampere_arch(arch) or is_ada_arch(arch): return self.dispatch_ampere_scheduler(arch) elif is_volta_arch(arch): return self.dispatch_volta_scheduler(arch) @@ -154,10 +160,10 @@ def detect_scheduler_from_hint(self, hint: BaseTLHint) -> BaseScheduler: self.gemv_dequantize_simt_scheduler, self.matmul_dequantize_simt_scheduler, self.matmul_dequantize_block_scheduler, - self.matmul_dequantize_fine_grained_scheduler, - self.matmul_dequantize_weight_propagation_scheduler, - self.matmul_int4_dequantize_fine_grain_scheduler, - self.matmul_int4_dequantize_weight_propagation_scheduler, + self.matmul_dequantize_mma_scheduler, + self.matmul_dequantize_mma_weight_propagation_scheduler, + self.matmul_int4_dequantize_mma_scheduler, + self.matmul_int4_dequantize_mma_weight_propagation_scheduler, ]: try: scheduler_hint_type = scheduler.get_hint_type() @@ -224,10 +230,10 @@ def set_dynamic_range(self, dynamic_range: Dict[str, int]) -> "BaseScheduler": self.gemv_dequantize_simt_scheduler, self.matmul_dequantize_simt_scheduler, self.matmul_dequantize_block_scheduler, - self.matmul_dequantize_fine_grained_scheduler, - self.matmul_dequantize_weight_propagation_scheduler, - self.matmul_int4_dequantize_fine_grain_scheduler, - self.matmul_int4_dequantize_weight_propagation_scheduler, + self.matmul_dequantize_mma_scheduler, + self.matmul_dequantize_mma_weight_propagation_scheduler, + self.matmul_int4_dequantize_mma_scheduler, + self.matmul_int4_dequantize_mma_weight_propagation_scheduler, ]: scheduler.set_dynamic_range(dynamic_range) return self @@ -238,10 +244,10 @@ def with_arch(self, arch): self.gemv_dequantize_simt_scheduler, self.matmul_dequantize_simt_scheduler, self.matmul_dequantize_block_scheduler, - self.matmul_dequantize_fine_grained_scheduler, - self.matmul_dequantize_weight_propagation_scheduler, - self.matmul_int4_dequantize_fine_grain_scheduler, - self.matmul_int4_dequantize_weight_propagation_scheduler, + self.matmul_dequantize_mma_scheduler, + self.matmul_dequantize_mma_weight_propagation_scheduler, + self.matmul_int4_dequantize_mma_scheduler, + self.matmul_int4_dequantize_mma_weight_propagation_scheduler, ]: scheduler.with_arch(arch) return self diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_mma.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_mma.py index aea3d331e..eb02b0a30 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_mma.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_mma.py @@ -5,10 +5,10 @@ from bitblas import tilelang as tilelang import tilelang.language as T from typing import Optional, List -from bitblas.tl.utils import ( - get_mma_micro_size, # noqa: F401 - make_mma_swizzle_layout as make_swizzle_layout, # noqa: F401 - index_to_coordinates, # noqa: F401 +from tilelang.intrinsics.utils import ( + get_mma_micro_size, + make_mma_swizzle_layout as make_swizzle_layout, + index_to_coordinates, ) from bitblas.ops.general_matmul.tirscript import ( matmul_dequantize_select_implementation,) diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_mma_weight_transform.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_mma_weight_transform.py index 67330730d..b65465902 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_mma_weight_transform.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_mma_weight_transform.py @@ -5,10 +5,10 @@ from tvm import DataType import tilelang.language as T from typing import Optional, List -from bitblas.tl.utils import ( - get_mma_micro_size, # noqa: F401 - make_mma_swizzle_layout as make_swizzle_layout, # noqa: F401 - index_to_coordinates, # noqa: F401 +from tilelang.intrinsics.utils import ( + get_mma_micro_size, + make_mma_swizzle_layout as make_swizzle_layout, + index_to_coordinates, ) from bitblas.base.arch import TileDevice from bitblas.base.roller.hint import Hint From 895080ed4f34c63fef820790ba7c21383cfeaa7e Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Wed, 5 Feb 2025 16:56:59 +0000 Subject: [PATCH 22/28] Disable tuning for Matmul in FP8 tests and update backend configuration --- testing/python/operators/test_general_matmul_fp8.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/testing/python/operators/test_general_matmul_fp8.py b/testing/python/operators/test_general_matmul_fp8.py index 0cdf2d545..50e58b409 100644 --- a/testing/python/operators/test_general_matmul_fp8.py +++ b/testing/python/operators/test_general_matmul_fp8.py @@ -27,7 +27,7 @@ def matmul_torch_forward(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layo with_zeros=with_zeros, zeros_mode=zeros_mode, ) - matmul = Matmul(config=matmul_config, enable_tuning=True, backend="tl") + matmul = Matmul(config=matmul_config, enable_tuning=False, backend="tl") input_shape = (M, K) weight_shape = (N, K) if layout == "nt" else (K, N) @@ -93,7 +93,7 @@ def matmul_torch_forward_weight_dequantize(M, N, K, A_dtype, W_dtype, accum_dtyp propagate_a=False, propagate_b=False, ) - matmul = Matmul(config=matmul_config, enable_tuning=False, backend="tl") + matmul = Matmul(config=matmul_config, enable_tuning=False, backend="tir") input_shape = (M, K) weight_shape = (N, K) if layout == "nt" else (K, N) From ab990efaa62142f97954ee566d0c7cdfb12b0bfe Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Thu, 6 Feb 2025 13:50:52 +0000 Subject: [PATCH 23/28] Update TileLang subproject to latest commit --- 3rdparty/tilelang | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tilelang b/3rdparty/tilelang index fffda9358..d0396a609 160000 --- a/3rdparty/tilelang +++ b/3rdparty/tilelang @@ -1 +1 @@ -Subproject commit fffda93581572d3393ec2f96483533ffa6f72c1e +Subproject commit d0396a609ab21fb84174a65e2a188b0c265b271f From 9d4bcc944a061504bff864e8d7b467d90f53358d Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Thu, 6 Feb 2025 13:50:58 +0000 Subject: [PATCH 24/28] Refactor code for improved readability by adjusting line breaks and formatting in matmul analysis and dequantization modules --- bitblas/gpu/matmul_analysis.py | 4 +++- bitblas/ops/general_matmul/tilelang/dense/matmul.py | 3 +-- .../general_matmul/tilelang/dequantize/matmul_dequantize.py | 3 +-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/bitblas/gpu/matmul_analysis.py b/bitblas/gpu/matmul_analysis.py index f954f3f88..607f81fff 100644 --- a/bitblas/gpu/matmul_analysis.py +++ b/bitblas/gpu/matmul_analysis.py @@ -649,7 +649,9 @@ def check_last_trait(region: List[Range]): if target.kind.name == "cuda" and check_sm_version(target.arch) >= 70: in_dtype, out_dtype = get_in_out_dtypes(block_stmt) if not is_tensorcore_supported_precision(in_dtype, out_dtype, arch=get_arch(target)): - logger.debug(f"The input and output dtype ({in_dtype}, {out_dtype})is not supported by tensorcore") + logger.debug( + f"The input and output dtype ({in_dtype}, {out_dtype})is not supported by tensorcore" + ) return func, None # reindex and transform functions diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul.py b/bitblas/ops/general_matmul/tilelang/dense/matmul.py index 7735b6a77..ef1bfed37 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/matmul.py +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul.py @@ -81,8 +81,7 @@ def dispatch_ampere_scheduler(self, arch: TileDevice) -> BaseScheduler: return self.matmul_simt_scheduler else: _, _, micro_size_k = get_mma_micro_size(in_dtype) - minimal_tensorcore_threshold: List[int, int, - int] = [8, 16, micro_size_k] + minimal_tensorcore_threshold: List[int, int, int] = [8, 16, micro_size_k] if minimal_tensorcore_threshold[0] > M or minimal_tensorcore_threshold[ 1] > N or minimal_tensorcore_threshold[2] > K: if in_dtype == "int4": diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize.py index 574015950..562714fe2 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize.py @@ -56,8 +56,7 @@ def __init__(self, **kwargs): self.matmul_dequantize_mma_scheduler = MatmulDequantizeMMAScheduler(**kwargs) self.matmul_dequantize_mma_weight_propagation_scheduler = MatmulDequantizeMMAWeightPropagationScheduler( **kwargs) - self.matmul_int4_dequantize_mma_scheduler = MatmulINT4DequantizeMMAScheduler( - **kwargs) + self.matmul_int4_dequantize_mma_scheduler = MatmulINT4DequantizeMMAScheduler(**kwargs) self.matmul_int4_dequantize_mma_weight_propagation_scheduler = MatmulINT4DequantizeMMAWeightPropagationScheduler( **kwargs) From b6edc52c47bbed208fb46c2d7b860a70c301cca2 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Fri, 7 Feb 2025 07:09:04 +0000 Subject: [PATCH 25/28] Update float16 type mapping to use 'half_t' in TLCUDASourceWrapper --- bitblas/builder/wrapper/tl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitblas/builder/wrapper/tl.py b/bitblas/builder/wrapper/tl.py index c64ca2b8d..e3eba9883 100644 --- a/bitblas/builder/wrapper/tl.py +++ b/bitblas/builder/wrapper/tl.py @@ -18,7 +18,7 @@ class TLCUDASourceWrapper(object): _TYPE_MAP = { "float32": "float", - "float16": "half", + "float16": "half_t", "bfloat16": "__nv_bfloat16", "e4m3_float8": "__nv_fp8_e4m3", "e5m2_float8": "__nv_fp8_e5m2", From 167c82b09e8f1b7b68248482bb87f0e18b9912ff Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Fri, 7 Feb 2025 07:35:16 +0000 Subject: [PATCH 26/28] Update TileLang subproject and improve error logging in tuner.py --- 3rdparty/tilelang | 2 +- bitblas/tl/tuner.py | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/3rdparty/tilelang b/3rdparty/tilelang index d0396a609..b09e2b5cc 160000 --- a/3rdparty/tilelang +++ b/3rdparty/tilelang @@ -1 +1 @@ -Subproject commit d0396a609ab21fb84174a65e2a188b0c265b271f +Subproject commit b09e2b5cc6abfe94c35249cb99ad899ef394964e diff --git a/bitblas/tl/tuner.py b/bitblas/tl/tuner.py index 0c6eded9c..755d6996a 100644 --- a/bitblas/tl/tuner.py +++ b/bitblas/tl/tuner.py @@ -140,11 +140,14 @@ def tvm_callback_cuda_postproc(code, _): for future in as_completed(future_to_idx, timeout=timeout): idx = future_to_idx[future] + assert idx <= len(_scheduled_ir_modules), "Index out of range" + assert idx <= len(configs), "Index out of range" + + ir_module = _scheduled_ir_modules[idx] + config = configs[idx] try: idx, code, artifact_path = future.result() - ir_module = _scheduled_ir_modules[idx] sch = tvm.tir.Schedule(ir_module) - config = configs[idx] if artifact_path is None: ARTIFACT_NOT_FOUND = f"Apply config {config} failed, artifact path is None" @@ -165,7 +168,7 @@ def tvm_callback_cuda_postproc(code, _): local_build_error = ( local_build_error[:MAX_ERROR_MESSAGE_LENGTH] + "\t...\t" + local_build_error[-MAX_ERROR_MESSAGE_LENGTH:]) - logger.error(f"An exception occurred for index {idx}: {local_build_error}") + logger.error(f"An exception occurred for hint {config}: {local_build_error}") best = None best_latency = 1e9 From 9b86b6d22745ebb04a8612c2b06bf5211a8033ef Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sat, 8 Feb 2025 04:19:04 +0000 Subject: [PATCH 27/28] Remove unnecessary whitespace in tuner.py for cleaner code --- bitblas/tl/tuner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitblas/tl/tuner.py b/bitblas/tl/tuner.py index 755d6996a..539838393 100644 --- a/bitblas/tl/tuner.py +++ b/bitblas/tl/tuner.py @@ -142,7 +142,7 @@ def tvm_callback_cuda_postproc(code, _): idx = future_to_idx[future] assert idx <= len(_scheduled_ir_modules), "Index out of range" assert idx <= len(configs), "Index out of range" - + ir_module = _scheduled_ir_modules[idx] config = configs[idx] try: From 4fd0670f2a9a1c7213052b9ffac6211c8e43e7f4 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sat, 8 Feb 2025 09:43:10 +0000 Subject: [PATCH 28/28] Update bfloat16 type mapping to use 'bfloat16_t' in TLCUDASourceWrapper --- bitblas/builder/wrapper/tl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitblas/builder/wrapper/tl.py b/bitblas/builder/wrapper/tl.py index e3eba9883..5c635fed9 100644 --- a/bitblas/builder/wrapper/tl.py +++ b/bitblas/builder/wrapper/tl.py @@ -19,7 +19,7 @@ class TLCUDASourceWrapper(object): _TYPE_MAP = { "float32": "float", "float16": "half_t", - "bfloat16": "__nv_bfloat16", + "bfloat16": "bfloat16_t", "e4m3_float8": "__nv_fp8_e4m3", "e5m2_float8": "__nv_fp8_e5m2", "float64": "double",