diff --git a/bitblas/ops/general_matmul/tilelang/dense/__init__.py b/bitblas/ops/general_matmul/tilelang/dense/__init__.py index 8ae3bc500..17843fc0f 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/__init__.py +++ b/bitblas/ops/general_matmul/tilelang/dense/__init__.py @@ -5,12 +5,14 @@ MatmulFineGrainSIMTScheduler, # noqa: F401 ) -from .matmul_tensorcore import ( - MatmulBlockScheduler, - MatmulFineGrainScheduler, - MatmulWeightPropagationScheduler, - MatmulINT4FineGrainScheduler, - MatmulINT4WeightPropagationScheduler, +from .matmul_tile import ( + MatmulTileLibraryScheduler,) + +from .matmul_mma import ( + MatmulMMAScheduler, + MatmulMMAWeightPropagationScheduler, + MatmulINT4MMAScheduler, + MatmulINT4MMAWeightPropagationScheduler, ) from .matmul import MatmulScheduler @@ -126,8 +128,8 @@ def is_int4_dtype(dtype): return dtype == "int4" or dtype == "uint4" if can_apply_weight_propagation_scheduler(trans_A, trans_B, propagate_a, propagate_b): - Scheduler = MatmulWeightPropagationScheduler if not is_int4_dtype( - in_dtype) else MatmulINT4WeightPropagationScheduler + Scheduler = MatmulMMAWeightPropagationScheduler if not is_int4_dtype( + in_dtype) else MatmulINT4MMAWeightPropagationScheduler return Scheduler( M=M, N=N, @@ -140,8 +142,7 @@ def is_int4_dtype(dtype): with_bias=with_bias, ) if can_apply_fine_grain_scheduler(trans_A, trans_B, propagate_a, propagate_b): - Scheduler = MatmulFineGrainScheduler if not is_int4_dtype( - in_dtype) else MatmulINT4FineGrainScheduler + Scheduler = MatmulMMAScheduler if not is_int4_dtype(in_dtype) else MatmulINT4MMAScheduler return Scheduler( M=M, N=N, @@ -154,7 +155,7 @@ def is_int4_dtype(dtype): with_bias=with_bias, ) elif can_apply_block_scheduler(propagate_a, propagate_b): - return MatmulBlockScheduler( + return MatmulTileLibraryScheduler( M=M, N=N, K=K, diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul.py b/bitblas/ops/general_matmul/tilelang/dense/matmul.py index 8b197e569..d18a77c8d 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/matmul.py +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul.py @@ -17,12 +17,13 @@ from .base import MatmulBaseParams from .gemv_simt import GemvFineGrainSIMTScheduler from .matmul_simt import MatmulFineGrainSIMTScheduler -from .matmul_tensorcore import ( - MatmulBlockScheduler, - MatmulFineGrainScheduler, - MatmulWeightPropagationScheduler, - MatmulINT4FineGrainScheduler, - MatmulINT4WeightPropagationScheduler, +from .matmul_tile import ( + MatmulTileLibraryScheduler,) +from .matmul_mma import ( + MatmulMMAScheduler, + MatmulMMAWeightPropagationScheduler, + MatmulINT4MMAScheduler, + MatmulINT4MMAWeightPropagationScheduler, ) import logging @@ -37,20 +38,21 @@ class MatmulScheduler(MatmulBaseParams): gemv_scheduler: Optional[GemvFineGrainSIMTScheduler] = None matmul_simt_scheduler: Optional[MatmulFineGrainSIMTScheduler] = None - matmul_block_scheduler: Optional[MatmulBlockScheduler] = None - matmul_fine_grain_scheduler: Optional[MatmulFineGrainScheduler] = None - matmul_weight_propagation_scheduler: Optional[MatmulWeightPropagationScheduler] = None - matmul_int4_fine_grain_scheduler: Optional[MatmulINT4FineGrainScheduler] = None - matmul_int4_weight_propagation_scheduler: Optional[MatmulINT4WeightPropagationScheduler] = None + matmul_block_scheduler: Optional[MatmulTileLibraryScheduler] = None + matmul_fine_grain_scheduler: Optional[MatmulMMAScheduler] = None + matmul_weight_propagation_scheduler: Optional[MatmulMMAWeightPropagationScheduler] = None + matmul_int4_fine_grain_scheduler: Optional[MatmulINT4MMAScheduler] = None + matmul_int4_weight_propagation_scheduler: Optional[ + MatmulINT4MMAWeightPropagationScheduler] = None def __init__(self, **kwargs): self.gemv_scheduler = GemvFineGrainSIMTScheduler(**kwargs) self.matmul_simt_scheduler = MatmulFineGrainSIMTScheduler(**kwargs) - self.matmul_block_scheduler = MatmulBlockScheduler(**kwargs) - self.matmul_fine_grain_scheduler = MatmulFineGrainScheduler(**kwargs) - self.matmul_weight_propagation_scheduler = MatmulWeightPropagationScheduler(**kwargs) - self.matmul_int4_fine_grain_scheduler = MatmulINT4FineGrainScheduler(**kwargs) - self.matmul_int4_weight_propagation_scheduler = MatmulINT4WeightPropagationScheduler( + self.matmul_block_scheduler = MatmulTileLibraryScheduler(**kwargs) + self.matmul_fine_grain_scheduler = MatmulMMAScheduler(**kwargs) + self.matmul_weight_propagation_scheduler = MatmulMMAWeightPropagationScheduler(**kwargs) + self.matmul_int4_fine_grain_scheduler = MatmulINT4MMAScheduler(**kwargs) + self.matmul_int4_weight_propagation_scheduler = MatmulINT4MMAWeightPropagationScheduler( **kwargs) super().__init__(**kwargs) diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_mma.py similarity index 66% rename from bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py rename to bitblas/ops/general_matmul/tilelang/dense/matmul_mma.py index f2bb5bd4d..c28ebee9e 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_mma.py @@ -1,5 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +# tile represents tile library + from bitblas import tvm as tvm from tvm import DataType import tvm.tl.language as T @@ -23,248 +25,14 @@ from dataclasses import dataclass from bitblas.ops.general_matmul.tirscript import (matmul_select_implementation) from bitblas.tl.base_hint import BaseTLHint -from .base import MatmulBaseParams +from .matmul_tile import MatmulBaseScheduler + # GPU warp configuration for NVIDIA GPUs warp_size = 32 @dataclass -class MatmulBaseScheduler(MatmulBaseParams): - - def get_roller_configs(self, arch: TileDevice = None, topk: int = 10): - layout = f"{'t' if self.trans_A else 'n'}{'t' if self.trans_B else 'n'}" - - # Simple TIR Compute Expression - ir_module = matmul_select_implementation( - M=self.M, - N=self.N, - K=self.K, - in_dtype=self.in_dtype, - out_dtype=self.out_dtype, - accum_dtype=self.accum_dtype, - layout=layout, - ) - - roller_hints = get_roller_hints_from_func( - ir_module, - arch, - topk, - tensorcore_only=True, - allow_gemv=True, - ) - - if roller_hints is None: - raise ValueError("No Roller Hints Found for TensorCore Scheduling") - - return self.serialize_hints_to_configs(roller_hints) - - def get_hardware_aware_configs(self, arch: TileDevice = None, topk=10): - if arch is None: - arch = self.arch - return self.get_roller_configs(arch, topk) - - # check if required shared memory cache - def check_require_cache(self) -> bool: - with_bias = self.with_bias - - conditions: List[bool] = [] - conditions.append(False) - # Bias Add should be performed in shared memory - conditions.append(with_bias) - return any(conditions) # Always set to False Currently - - -@dataclass -class MatmulBlockScheduler(MatmulBaseScheduler): - - # Default Tile Related Params - block_M: int = 64 - block_N: int = 64 - block_K: int = 32 - num_stages: int = 2 - threads: int = 128 - enable_rasterization: bool = False # Enhance L2 Locality - - class TLHint(BaseTLHint): - - hint_type = "MatmulBlockScheduler" - - def __init__(self): - super().__init__() - - @classmethod - def from_roller_hint(cls, hint: Hint): - tl_hint = cls() - for key, value in hint.__dict__.items(): - setattr(tl_hint, key, value) - - block = hint.block - warp = hint.warp - rstep = hint.rstep - num_stages = hint.pipeline_stage - rasterization_plan = hint.rasterization_plan - enable_rasterization = not isinstance(rasterization_plan, NoRasterization) - - block_row_warps = block[0] // warp[0] - block_col_warps = block[1] // warp[1] - warp_size = 32 # NVIDIA GPU warp size is 32 - if num_stages == 1: - num_stages = 0 # disable pipelining - - tl_hint.block_M = block[0] - tl_hint.block_N = block[1] - tl_hint.block_K = rstep[0] - tl_hint.num_stages = num_stages - tl_hint.threads = warp_size * block_row_warps * block_col_warps - tl_hint.enable_rasterization = enable_rasterization - - return tl_hint - - def get_config_params(self): - return { - "block_M": self.block_M, - "block_N": self.block_N, - "block_K": self.block_K, - "num_stages": self.num_stages, - "threads": self.threads, - "enable_rasterization": self.enable_rasterization, - } - - def __repr__(self): - return ("{" - f"block_M={self.block_M}," - f"block_N={self.block_N}," - f"block_K={self.block_K}," - f"num_stages={self.num_stages}," - f"threads={self.threads}," - f"enable_rasterization={self.enable_rasterization}" - "}") - - def get_configs_sm80(self): - num_stages = 2 - configs = [ - { - 'block_M': 128, - 'block_N': 256, - 'block_K': 32, - 'threads': 128 - }, - { - 'block_M': 256, - 'block_N': 128, - 'block_K': 32, - 'threads': 128 - }, - { - 'block_M': 128, - 'block_N': 128, - 'block_K': 32, - 'threads': 128 - }, - ] - configs = [{**c, 'num_stages': num_stages} for c in configs] - return configs - - def get_hint_type(self): - return self.TLHint.hint_type - - def serialize_hints_to_configs(self, hints: List[Hint]): - configs = [] - for hint in hints: - config = self.TLHint.from_roller_hint(hint) - configs.append(config) - return configs - - def with_default_config(self): - block_M = getattr(self, "block_M", 64) - block_N = getattr(self, "block_N", 64) - block_K = getattr(self, "block_K", 32) - num_stages = getattr(self, "num_stages", 2) - threads = getattr(self, "threads", 128) - enable_rasterization = getattr(self, "enable_rasterization", False) - - return self.apply_config( - block_M=block_M, - block_N=block_N, - block_K=block_K, - num_stages=num_stages, - threads=threads, - enable_rasterization=enable_rasterization, - ) - - def apply_config( - self, - block_M: Optional[int] = None, - block_N: Optional[int] = None, - block_K: Optional[int] = None, - num_stages: Optional[int] = None, - threads: Optional[int] = None, - # Enhance L2 Locality - enable_rasterization: bool = False, - ): - assert block_M is not None, "block_M is required" - assert block_N is not None, "block_N is required" - assert block_K is not None, "block_K is required" - assert num_stages is not None, "num_stages is required" - assert threads is not None, "threads is required" - - M = self.maybe_dynamic(self.M, "m") - N, K = self.N, self.K - assert isinstance(N, int) and isinstance(K, int), "Do not support dynamic N and K Currently" - - trans_A, trans_B = self.trans_A, self.trans_B - in_dtype, out_dtype, accum_dtype = self.in_dtype, self.out_dtype, self.accum_dtype - with_bias = self.with_bias - - A_shape = (K, M) if trans_A else (M, K) - B_shape = (N, K) if trans_B else (K, N) - C_shape = (M, N) - Bias_shape = (N,) - A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) - B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) - - @T.prim_func - def main( - A: T.Buffer(A_shape, in_dtype), - B: T.Buffer(B_shape, in_dtype), - Bias: T.Buffer(Bias_shape, out_dtype), - C: T.Buffer(C_shape, out_dtype), - ): - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype) - B_shared = T.alloc_shared(B_shared_shape, in_dtype) - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - - T.use_swizzle(10, enable=enable_rasterization) - - T.clear(C_local) - for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - if trans_A: - T.copy(A[k * block_K, by * block_M], A_shared) - else: - T.copy(A[by * block_M, k * block_K], A_shared) - if trans_B: - T.copy(B[bx * block_N, k * block_K], B_shared) - else: - T.copy(B[k * block_K, bx * block_N], B_shared) - T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) - - if with_bias: - for i, j in T.Parallel(block_M, block_N): - C_local[i, j] += Bias[bx * block_N + j] - - T.copy(C_local, C[by * block_M, bx * block_N]) - - return self.post_process(main) - - def __post_init__(self): - # Add Config Validation - return - - -@dataclass -class MatmulFineGrainScheduler(MatmulBaseScheduler): +class MatmulMMAScheduler(MatmulBaseScheduler): # Fine-grained matrix multiplication scheduler # Allows for more detailed configuration. @@ -281,7 +49,7 @@ class MatmulFineGrainScheduler(MatmulBaseScheduler): class TLHint(BaseTLHint): - hint_type: str = "MatmulFineGrainScheduler" + hint_type: str = "MatmulMMAScheduler" def __init__(self): super().__init__() @@ -561,13 +329,13 @@ def __post_init__(self): @dataclass -class MatmulWeightPropagationScheduler(MatmulFineGrainScheduler): +class MatmulMMAWeightPropagationScheduler(MatmulMMAScheduler): # force set default weight transform kind to LDMatrixTransform weight_transform_kind: TransformKind = TransformKind.LDMatrixTransform - class TLHint(MatmulFineGrainScheduler.TLHint): - hint_type: str = "MatmulWeightPropagationScheduler" + class TLHint(MatmulMMAScheduler.TLHint): + hint_type: str = "MatmulMMAWeightPropagationScheduler" def apply_config( self, @@ -794,11 +562,11 @@ def is_b_smooth(self): @dataclass -class MatmulINT4FineGrainScheduler(MatmulFineGrainScheduler): +class MatmulINT4MMAScheduler(MatmulMMAScheduler): @dataclass - class TLHint(MatmulFineGrainScheduler.TLHint): - hint_type: str = "MatmulINT4FineGrainScheduler" + class TLHint(MatmulMMAScheduler.TLHint): + hint_type: str = "MatmulINT4MMAScheduler" def get_roller_configs(self, arch: TileDevice = None, topk: int = 10): layout = f"{'t' if self.trans_A else 'n'}{'t' if self.trans_B else 'n'}" @@ -1011,10 +779,10 @@ def __post_init__(self): @dataclass -class MatmulINT4WeightPropagationScheduler(MatmulWeightPropagationScheduler): +class MatmulINT4MMAWeightPropagationScheduler(MatmulMMAWeightPropagationScheduler): - class TLHint(MatmulWeightPropagationScheduler.TLHint): - hint_type: str = "MatmulINT4WeightPropagationScheduler" + class TLHint(MatmulMMAWeightPropagationScheduler.TLHint): + hint_type: str = "MatmulINT4MMAWeightPropagationScheduler" def get_roller_configs(self, arch: TileDevice = None, topk: int = 10): layout = f"{'t' if self.trans_A else 'n'}{'t' if self.trans_B else 'n'}" @@ -1252,328 +1020,3 @@ def __post_init__(self): assert self.trans_B is True, "Currently only support Matrix B transposed" return - - -def matmul_blocked( - M, - N, - K, - block_M=64, - block_N=64, - block_K=32, - trans_A=False, - trans_B=False, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", - num_stages=2, - threads=128, - enable_rasterization: bool = False, # Enhance L2 Locality -): - A_shape = (K, M) if trans_A else (M, K) - B_shape = (N, K) if trans_B else (K, N) - C_shape = (M, N) - A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) - B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) - - @T.prim_func - def main( - A: T.Buffer(A_shape, in_dtype), - B: T.Buffer(B_shape, in_dtype), - C: T.Buffer(C_shape, out_dtype), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - A_shared = T.alloc_shared(A_shared_shape, in_dtype) - B_shared = T.alloc_shared(B_shared_shape, in_dtype) - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - - T.use_swizzle(10, enable=enable_rasterization) - - T.clear(C_local) - for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - if trans_A: - T.copy(A[k * block_K, by * block_M], A_shared) - else: - T.copy(A[by * block_M, k * block_K], A_shared) - if trans_B: - T.copy(B[bx * block_N, k * block_K], B_shared) - else: - T.copy(B[k * block_K, bx * block_N], B_shared) - T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) - T.copy(C_local, C[by * block_M, bx * block_N]) - - return main - - -def matmul_macro_tensorcore( - M, - N, - K, - in_dtype, - out_dtype, - trans_A, - trans_B, - accum_dtype, - block_row_warps, - block_col_warps, - warp_row_tiles, - warp_col_tiles, - chunk, - num_stages=2, - enable_rasterization: bool = False, -): - assert trans_A is False, "Currently only support Matrix A is not transposed" - assert trans_B is True, "Currently only support Matrix B is transposed" - - block_M = block_row_warps * warp_row_tiles - block_N = block_col_warps * warp_col_tiles - block_K = chunk - - micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(in_dtype) - - A_shape = (M, K) - B_shape = (N, K) - A_shared_shape = (block_M, block_K) - B_shared_shape = (block_N, block_K) - C_shared_shape = ( - block_M // micro_size_x, - block_N // micro_size_y, - micro_size_x, - micro_size_y, - ) - - warp_size = 32 # nvidia gpu warp size is 32 - threads = warp_size * (block_row_warps * block_col_warps) - local_size_a = (micro_size_x * micro_size_k) // warp_size - local_size_b = (micro_size_y * micro_size_k) // warp_size - local_size_c = (micro_size_x * micro_size_y) // warp_size - warp_rows = warp_row_tiles // micro_size_x - warp_cols = warp_col_tiles // micro_size_y - - shared_scope = "shared.dyn" # Literal["shared", "shared.dyn"] while shared for static shared memory - mma_emitter = TensorCoreIntrinEmitter( - a_dtype=in_dtype, - b_dtype=in_dtype, - accum_dtype=accum_dtype, - a_transposed=False, - b_transposed=True, - block_row_warps=block_row_warps, - block_col_warps=block_col_warps, - warp_row_tiles=warp_row_tiles, - warp_col_tiles=warp_col_tiles, - chunk=chunk, - ) - - @T.prim_func - def main( - A: T.Buffer(A_shape, in_dtype), - B: T.Buffer(B_shape, in_dtype), - C: T.Buffer((M, N), out_dtype), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) - B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) - C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) - A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) - B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) - C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) - thread_bindings = T.thread_binding(0, threads, "threadIdx.x") - - T.annotate_layout({ - A_shared: make_swizzle_layout(A_shared), - B_shared: make_swizzle_layout(B_shared), - }) - - T.use_swizzle(panel_size=10, enable=enable_rasterization) - - T.clear(C_local) - - for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - - for i, k in T.Parallel(block_M, block_K): - A_shared[i, k] = A[by * block_M + i, ko * block_K + k] - - for j, k in T.Parallel(block_N, block_K): - B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] - - for ki in T.serial(0, (block_K // micro_size_k)): - - # Load A into fragment - mma_emitter.ldmatrix_a( - A_local, - A_shared, - ki, - thread_bindings=thread_bindings, - ) - - # Load B into fragment - mma_emitter.ldmatrix_b( - B_local, - B_shared, - ki, - thread_bindings=thread_bindings, - ) - - mma_emitter.mma(A_local, B_local, C_local) - - mma_emitter.stmatrix( - C_local, - C_shared, - thread_bindings=thread_bindings, - ) - - for i, j in T.Parallel(block_M, block_N): - C[by * block_M + i, bx * block_N + j] = C_shared[ - i // micro_size_x, - j // micro_size_y, - i % micro_size_x, - j % micro_size_y, - ] - - return main - - -def matmul_macro_tensorcore_weight_propagation_level_ldmatrix( - M, - N, - K, - in_dtype, - out_dtype, - trans_A, - trans_B, - accum_dtype, - block_row_warps, - block_col_warps, - warp_row_tiles, - warp_col_tiles, - chunk, - num_stages=2, - enable_rasterization: bool = False, -): - assert trans_A is False, "Currently only support Matrix A is not transposed" - assert trans_B is True, "Currently only support Matrix B is transposed" - - block_M = block_row_warps * warp_row_tiles - block_N = block_col_warps * warp_col_tiles - block_K = chunk - - micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(in_dtype) - - A_shape = (M, K) - B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y, micro_size_k) - A_shared_shape = (block_M, block_K) - B_shared_shape = ( - block_N // micro_size_y, - block_K // micro_size_k, - micro_size_y, - micro_size_k, - ) - C_shared_shape = ( - block_M // micro_size_x, - block_N // micro_size_y, - micro_size_x, - micro_size_y, - ) - - warp_size = 32 # nvidia gpu warp size is 32 - threads = warp_size * (block_row_warps * block_col_warps) - local_size_a = (micro_size_x * micro_size_k) // warp_size - local_size_b = (micro_size_y * micro_size_k) // warp_size - local_size_c = (micro_size_x * micro_size_y) // warp_size - warp_rows = warp_row_tiles // micro_size_x - warp_cols = warp_col_tiles // micro_size_y - - shared_scope = "shared.dyn" # Literal["shared", "shared.dyn"] while shared for static shared memory - mma_emitter = TensorCoreIntrinEmitterWithLadderTransform( - a_dtype=in_dtype, - b_dtype=in_dtype, - accum_dtype=accum_dtype, - a_transposed=trans_A, - b_transposed=trans_B, - block_row_warps=block_row_warps, - block_col_warps=block_col_warps, - warp_row_tiles=warp_row_tiles, - warp_col_tiles=warp_col_tiles, - chunk=chunk, - transform_kind_b=TransformKind.LDMatrixTransform, - ) - - @T.prim_func - def main( - A: T.Buffer(A_shape, in_dtype), - B: T.Buffer(B_shape, in_dtype), - C: T.Buffer((M, N), out_dtype), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): - - A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) - B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) - C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) - A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) - B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) - C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) - thread_bindings = T.thread_binding(0, threads, "threadIdx.x") - - T.annotate_layout({ - A_shared: make_swizzle_layout(A_shared), - B_shared: make_swizzle_layout(B_shared), - }) - - T.use_swizzle(panel_size=10, enable=enable_rasterization) - - T.clear(C_local) - - for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - - for i, k in T.Parallel(block_M, block_K): - A_shared[i, k] = A[by * block_M + i, ko * block_K + k] - - for j, k, jj, kk in T.Parallel( - block_N // micro_size_y, - block_K // micro_size_k, - micro_size_y, - micro_size_k, - ): - B_shared[j, k, jj, kk] = B[ - bx * (block_N // micro_size_y) + j, - ko * (block_K // micro_size_k) + k, - jj, - kk, - ] - - for ki in T.serial(0, (block_K // micro_size_k)): - - # Load A into fragment - mma_emitter.ldmatrix_a( - A_local, - A_shared, - ki, - thread_bindings=thread_bindings, - ) - - # Load B into fragment - mma_emitter.ldmatrix_b( - B_local, - B_shared, - ki, - thread_bindings=thread_bindings, - ) - - mma_emitter.mma(A_local, B_local, C_local) - - mma_emitter.stmatrix( - C_local, - C_shared, - thread_bindings=thread_bindings, - ) - - for i, j in T.Parallel(block_M, block_N): - C[by * block_M + i, bx * block_N + j] = C_shared[ - i // micro_size_x, - j // micro_size_y, - i % micro_size_x, - j % micro_size_y, - ] - - return main diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul_tile.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_tile.py new file mode 100644 index 000000000..8dceefd73 --- /dev/null +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_tile.py @@ -0,0 +1,586 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# tile represents tile library + +from bitblas import tvm as tvm +import tvm.tl.language as T +from typing import Optional, List +from bitblas.tl.utils import ( + get_mma_micro_size, + make_mma_swizzle_layout as make_swizzle_layout, +) + +from bitblas.tl.mma_macro_generator import (TensorCoreIntrinEmitter, + TensorCoreIntrinEmitterWithLadderTransform) +from bitblas.base.operator_common import TransformKind +from bitblas.base.arch import TileDevice +from bitblas.base.roller.hint import Hint +from bitblas.base.roller.rasterization import NoRasterization +from bitblas.base.utils import get_roller_hints_from_func +from dataclasses import dataclass +from bitblas.ops.general_matmul.tirscript import (matmul_select_implementation) +from bitblas.tl.base_hint import BaseTLHint +from .base import MatmulBaseParams +# GPU warp configuration for NVIDIA GPUs +warp_size = 32 + + +@dataclass +class MatmulBaseScheduler(MatmulBaseParams): + + def get_roller_configs(self, arch: TileDevice = None, topk: int = 10): + layout = f"{'t' if self.trans_A else 'n'}{'t' if self.trans_B else 'n'}" + + # Simple TIR Compute Expression + ir_module = matmul_select_implementation( + M=self.M, + N=self.N, + K=self.K, + in_dtype=self.in_dtype, + out_dtype=self.out_dtype, + accum_dtype=self.accum_dtype, + layout=layout, + ) + + roller_hints = get_roller_hints_from_func( + ir_module, + arch, + topk, + tensorcore_only=True, + allow_gemv=True, + ) + + if roller_hints is None: + raise ValueError("No Roller Hints Found for TensorCore Scheduling") + + return self.serialize_hints_to_configs(roller_hints) + + def get_hardware_aware_configs(self, arch: TileDevice = None, topk=10): + if arch is None: + arch = self.arch + return self.get_roller_configs(arch, topk) + + # check if required shared memory cache + def check_require_cache(self) -> bool: + with_bias = self.with_bias + + conditions: List[bool] = [] + conditions.append(False) + # Bias Add should be performed in shared memory + conditions.append(with_bias) + return any(conditions) # Always set to False Currently + + +@dataclass +class MatmulTileLibraryScheduler(MatmulBaseScheduler): + + # Default Tile Related Params + block_M: int = 64 + block_N: int = 64 + block_K: int = 32 + num_stages: int = 2 + threads: int = 128 + enable_rasterization: bool = False # Enhance L2 Locality + + class TLHint(BaseTLHint): + + hint_type = "MatmulTileLibraryScheduler" + + def __init__(self): + super().__init__() + + @classmethod + def from_roller_hint(cls, hint: Hint): + tl_hint = cls() + for key, value in hint.__dict__.items(): + setattr(tl_hint, key, value) + + block = hint.block + warp = hint.warp + rstep = hint.rstep + num_stages = hint.pipeline_stage + rasterization_plan = hint.rasterization_plan + enable_rasterization = not isinstance(rasterization_plan, NoRasterization) + + block_row_warps = block[0] // warp[0] + block_col_warps = block[1] // warp[1] + warp_size = 32 # NVIDIA GPU warp size is 32 + if num_stages == 1: + num_stages = 0 # disable pipelining + + tl_hint.block_M = block[0] + tl_hint.block_N = block[1] + tl_hint.block_K = rstep[0] + tl_hint.num_stages = num_stages + tl_hint.threads = warp_size * block_row_warps * block_col_warps + tl_hint.enable_rasterization = enable_rasterization + + return tl_hint + + def get_config_params(self): + return { + "block_M": self.block_M, + "block_N": self.block_N, + "block_K": self.block_K, + "num_stages": self.num_stages, + "threads": self.threads, + "enable_rasterization": self.enable_rasterization, + } + + def __repr__(self): + return ("{" + f"block_M={self.block_M}," + f"block_N={self.block_N}," + f"block_K={self.block_K}," + f"num_stages={self.num_stages}," + f"threads={self.threads}," + f"enable_rasterization={self.enable_rasterization}" + "}") + + def get_configs_sm80(self): + num_stages = 2 + configs = [ + { + 'block_M': 128, + 'block_N': 256, + 'block_K': 32, + 'threads': 128 + }, + { + 'block_M': 256, + 'block_N': 128, + 'block_K': 32, + 'threads': 128 + }, + { + 'block_M': 128, + 'block_N': 128, + 'block_K': 32, + 'threads': 128 + }, + ] + configs = [{**c, 'num_stages': num_stages} for c in configs] + return configs + + def get_hint_type(self): + return self.TLHint.hint_type + + def serialize_hints_to_configs(self, hints: List[Hint]): + configs = [] + for hint in hints: + config = self.TLHint.from_roller_hint(hint) + configs.append(config) + return configs + + def with_default_config(self): + block_M = getattr(self, "block_M", 64) + block_N = getattr(self, "block_N", 64) + block_K = getattr(self, "block_K", 32) + num_stages = getattr(self, "num_stages", 2) + threads = getattr(self, "threads", 128) + enable_rasterization = getattr(self, "enable_rasterization", False) + + return self.apply_config( + block_M=block_M, + block_N=block_N, + block_K=block_K, + num_stages=num_stages, + threads=threads, + enable_rasterization=enable_rasterization, + ) + + def apply_config( + self, + block_M: Optional[int] = None, + block_N: Optional[int] = None, + block_K: Optional[int] = None, + num_stages: Optional[int] = None, + threads: Optional[int] = None, + # Enhance L2 Locality + enable_rasterization: bool = False, + ): + assert block_M is not None, "block_M is required" + assert block_N is not None, "block_N is required" + assert block_K is not None, "block_K is required" + assert num_stages is not None, "num_stages is required" + assert threads is not None, "threads is required" + + M = self.maybe_dynamic(self.M, "m") + N, K = self.N, self.K + assert isinstance(N, int) and isinstance(K, int), "Do not support dynamic N and K Currently" + + trans_A, trans_B = self.trans_A, self.trans_B + in_dtype, out_dtype, accum_dtype = self.in_dtype, self.out_dtype, self.accum_dtype + with_bias = self.with_bias + + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + C_shape = (M, N) + Bias_shape = (N,) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + @T.prim_func + def main( + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, in_dtype), + Bias: T.Buffer(Bias_shape, out_dtype), + C: T.Buffer(C_shape, out_dtype), + ): + with T.Kernel( + T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + T.use_swizzle(10, enable=enable_rasterization) + + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) + + if with_bias: + for i, j in T.Parallel(block_M, block_N): + C_local[i, j] += Bias[bx * block_N + j] + + T.copy(C_local, C[by * block_M, bx * block_N]) + + return self.post_process(main) + + def __post_init__(self): + # Add Config Validation + return + + +# TODO(lei): remove these legacy functions in the future +def matmul_blocked( + M, + N, + K, + block_M=64, + block_N=64, + block_K=32, + trans_A=False, + trans_B=False, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + num_stages=2, + threads=128, + enable_rasterization: bool = False, # Enhance L2 Locality +): + A_shape = (K, M) if trans_A else (M, K) + B_shape = (N, K) if trans_B else (K, N) + C_shape = (M, N) + A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) + B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) + + @T.prim_func + def main( + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, in_dtype), + C: T.Buffer(C_shape, out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + T.use_swizzle(10, enable=enable_rasterization) + + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + if trans_A: + T.copy(A[k * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, k * block_K], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def matmul_macro_tensorcore( + M, + N, + K, + in_dtype, + out_dtype, + trans_A, + trans_B, + accum_dtype, + block_row_warps, + block_col_warps, + warp_row_tiles, + warp_col_tiles, + chunk, + num_stages=2, + enable_rasterization: bool = False, +): + assert trans_A is False, "Currently only support Matrix A is not transposed" + assert trans_B is True, "Currently only support Matrix B is transposed" + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + + micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(in_dtype) + + A_shape = (M, K) + B_shape = (N, K) + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_N, block_K) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + + warp_size = 32 # nvidia gpu warp size is 32 + threads = warp_size * (block_row_warps * block_col_warps) + local_size_a = (micro_size_x * micro_size_k) // warp_size + local_size_b = (micro_size_y * micro_size_k) // warp_size + local_size_c = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + shared_scope = "shared.dyn" # Literal["shared", "shared.dyn"] while shared for static shared memory + mma_emitter = TensorCoreIntrinEmitter( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=False, + b_transposed=True, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + ) + + @T.prim_func + def main( + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, in_dtype), + C: T.Buffer((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) + thread_bindings = T.thread_binding(0, threads, "threadIdx.x") + + T.annotate_layout({ + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + }) + + T.use_swizzle(panel_size=10, enable=enable_rasterization) + + T.clear(C_local) + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + + for i, k in T.Parallel(block_M, block_K): + A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + + for j, k in T.Parallel(block_N, block_K): + B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] + + for ki in T.serial(0, (block_K // micro_size_k)): + + # Load A into fragment + mma_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + thread_bindings=thread_bindings, + ) + + # Load B into fragment + mma_emitter.ldmatrix_b( + B_local, + B_shared, + ki, + thread_bindings=thread_bindings, + ) + + mma_emitter.mma(A_local, B_local, C_local) + + mma_emitter.stmatrix( + C_local, + C_shared, + thread_bindings=thread_bindings, + ) + + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] + + return main + + +def matmul_macro_tensorcore_weight_propagation_level_ldmatrix( + M, + N, + K, + in_dtype, + out_dtype, + trans_A, + trans_B, + accum_dtype, + block_row_warps, + block_col_warps, + warp_row_tiles, + warp_col_tiles, + chunk, + num_stages=2, + enable_rasterization: bool = False, +): + assert trans_A is False, "Currently only support Matrix A is not transposed" + assert trans_B is True, "Currently only support Matrix B is transposed" + + block_M = block_row_warps * warp_row_tiles + block_N = block_col_warps * warp_col_tiles + block_K = chunk + + micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(in_dtype) + + A_shape = (M, K) + B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y, micro_size_k) + A_shared_shape = (block_M, block_K) + B_shared_shape = ( + block_N // micro_size_y, + block_K // micro_size_k, + micro_size_y, + micro_size_k, + ) + C_shared_shape = ( + block_M // micro_size_x, + block_N // micro_size_y, + micro_size_x, + micro_size_y, + ) + + warp_size = 32 # nvidia gpu warp size is 32 + threads = warp_size * (block_row_warps * block_col_warps) + local_size_a = (micro_size_x * micro_size_k) // warp_size + local_size_b = (micro_size_y * micro_size_k) // warp_size + local_size_c = (micro_size_x * micro_size_y) // warp_size + warp_rows = warp_row_tiles // micro_size_x + warp_cols = warp_col_tiles // micro_size_y + + shared_scope = "shared.dyn" # Literal["shared", "shared.dyn"] while shared for static shared memory + mma_emitter = TensorCoreIntrinEmitterWithLadderTransform( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=trans_A, + b_transposed=trans_B, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + transform_kind_b=TransformKind.LDMatrixTransform, + ) + + @T.prim_func + def main( + A: T.Buffer(A_shape, in_dtype), + B: T.Buffer(B_shape, in_dtype), + C: T.Buffer((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) + C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) + C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) + thread_bindings = T.thread_binding(0, threads, "threadIdx.x") + + T.annotate_layout({ + A_shared: make_swizzle_layout(A_shared), + B_shared: make_swizzle_layout(B_shared), + }) + + T.use_swizzle(panel_size=10, enable=enable_rasterization) + + T.clear(C_local) + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + + for i, k in T.Parallel(block_M, block_K): + A_shared[i, k] = A[by * block_M + i, ko * block_K + k] + + for j, k, jj, kk in T.Parallel( + block_N // micro_size_y, + block_K // micro_size_k, + micro_size_y, + micro_size_k, + ): + B_shared[j, k, jj, kk] = B[ + bx * (block_N // micro_size_y) + j, + ko * (block_K // micro_size_k) + k, + jj, + kk, + ] + + for ki in T.serial(0, (block_K // micro_size_k)): + + # Load A into fragment + mma_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + thread_bindings=thread_bindings, + ) + + # Load B into fragment + mma_emitter.ldmatrix_b( + B_local, + B_shared, + ki, + thread_bindings=thread_bindings, + ) + + mma_emitter.mma(A_local, B_local, C_local) + + mma_emitter.stmatrix( + C_local, + C_shared, + thread_bindings=thread_bindings, + ) + + for i, j in T.Parallel(block_M, block_N): + C[by * block_M + i, bx * block_N + j] = C_shared[ + i // micro_size_x, + j // micro_size_y, + i % micro_size_x, + j % micro_size_y, + ] + + return main diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul_wmma.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_wmma.py index 447e7e47b..a7a307747 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/matmul_wmma.py +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_wmma.py @@ -6,7 +6,7 @@ from bitblas.base.roller.rasterization import NoRasterization from dataclasses import dataclass from bitblas.tl.base_hint import BaseTLHint -from .matmul_tensorcore import MatmulBaseScheduler +from .matmul_tile import MatmulBaseScheduler # GPU warp configuration for NVIDIA GPUs warp_size = 32 @@ -14,7 +14,7 @@ # TODO(lei): This is not implemented in the current version of the codebase @dataclass -class MatmulFineGrainScheduler(MatmulBaseScheduler): +class MatmulMMAScheduler(MatmulBaseScheduler): # Fine-grained matrix multiplication scheduler # Allows for more detailed configuration. @@ -31,7 +31,7 @@ class MatmulFineGrainScheduler(MatmulBaseScheduler): class TLHint(BaseTLHint): - hint_type: str = "MatmulFineGrainScheduler" + hint_type: str = "MatmulMMAScheduler" def __init__(self): super().__init__() diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/__init__.py b/bitblas/ops/general_matmul/tilelang/dequantize/__init__.py index 494a988d7..2c36e8bad 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/__init__.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/__init__.py @@ -1,18 +1,18 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from .matmul_dequantize_tensorcore import ( - MatmulDequantizeBlockScheduler, # noqa: F401 +from .matmul_dequantize_tile import ( + MatmulDequantizeTileLibraryScheduler, # noqa: F401 ) -from .matmul_dequantize_tensorcore_finegrained import ( - MatmulDequantizeFineGrainedScheduler, # noqa: F401 - MatmulINT4DequantizeFineGrainedScheduler, # noqa: F401 +from .matmul_dequantize_mma import ( + MatmulDequantizeMMAScheduler, # noqa: F401 + MatmulINT4DequantizeMMAScheduler, # noqa: F401 ) -from .matmul_dequantize_tensorcore_weight_transform import ( - MatmulDequantizeWeightPropagationScheduler, # noqa: F401 - MatmulINT4DequantizeWeightPropagationScheduler, # noqa: F401 +from .matmul_dequantize_mma_weight_transform import ( + MatmulDequantizeMMAWeightPropagationScheduler, # noqa: F401 + MatmulINT4DequantizeMMAWeightPropagationScheduler, # noqa: F401 ) from .matmul_dequantize import MatmulDequantizeScheduler diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize.py index 9716ac075..503453760 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize.py @@ -17,14 +17,14 @@ from .base import MatmulDequantizeBaseParams from .gemv_dequantize_simt import GemvDequantizeSIMTScheduler from .matmul_dequantize_simt import MatmulDequantizeSIMTScheduler -from .matmul_dequantize_tensorcore import MatmulDequantizeBlockScheduler -from .matmul_dequantize_tensorcore_finegrained import ( - MatmulDequantizeFineGrainedScheduler, - MatmulINT4DequantizeFineGrainedScheduler, +from .matmul_dequantize_tile import MatmulDequantizeTileLibraryScheduler +from .matmul_dequantize_mma import ( + MatmulDequantizeMMAScheduler, + MatmulINT4DequantizeMMAScheduler, ) -from .matmul_dequantize_tensorcore_weight_transform import ( - MatmulDequantizeWeightPropagationScheduler, - MatmulINT4DequantizeWeightPropagationScheduler, +from .matmul_dequantize_mma_weight_transform import ( + MatmulDequantizeMMAWeightPropagationScheduler, + MatmulINT4DequantizeMMAWeightPropagationScheduler, ) import logging @@ -38,26 +38,24 @@ class MatmulDequantizeScheduler(MatmulDequantizeBaseParams): # Allows for more detailed configuration. gemv_dequantize_simt_scheduler: Optional[GemvDequantizeSIMTScheduler] = None matmul_dequantize_simt_scheduler: Optional[MatmulDequantizeSIMTScheduler] = None - matmul_dequantize_block_scheduler: Optional[MatmulDequantizeBlockScheduler] = None - matmul_dequantize_fine_grained_scheduler: Optional[MatmulDequantizeFineGrainedScheduler] = None + matmul_dequantize_block_scheduler: Optional[MatmulDequantizeTileLibraryScheduler] = None + matmul_dequantize_fine_grained_scheduler: Optional[MatmulDequantizeMMAScheduler] = None matmul_dequantize_weight_propagation_scheduler: Optional[ - MatmulDequantizeWeightPropagationScheduler] = None - matmul_int4_dequantize_fine_grain_scheduler: Optional[ - MatmulINT4DequantizeFineGrainedScheduler] = None + MatmulDequantizeMMAWeightPropagationScheduler] = None + matmul_int4_dequantize_fine_grain_scheduler: Optional[MatmulINT4DequantizeMMAScheduler] = None matmul_int4_dequantize_weight_propagation_scheduler: Optional[ - MatmulINT4DequantizeWeightPropagationScheduler] = None + MatmulINT4DequantizeMMAWeightPropagationScheduler] = None def __init__(self, **kwargs): self.gemv_dequantize_simt_scheduler = GemvDequantizeSIMTScheduler(**kwargs) self.matmul_dequantize_simt_scheduler = MatmulDequantizeSIMTScheduler(**kwargs) - self.matmul_dequantize_block_scheduler = MatmulDequantizeBlockScheduler(**kwargs) - self.matmul_dequantize_fine_grained_scheduler = MatmulDequantizeFineGrainedScheduler( + self.matmul_dequantize_block_scheduler = MatmulDequantizeTileLibraryScheduler(**kwargs) + self.matmul_dequantize_fine_grained_scheduler = MatmulDequantizeMMAScheduler(**kwargs) + self.matmul_dequantize_weight_propagation_scheduler = MatmulDequantizeMMAWeightPropagationScheduler( **kwargs) - self.matmul_dequantize_weight_propagation_scheduler = MatmulDequantizeWeightPropagationScheduler( + self.matmul_int4_dequantize_fine_grain_scheduler = MatmulINT4DequantizeMMAScheduler( **kwargs) - self.matmul_int4_dequantize_fine_grain_scheduler = MatmulINT4DequantizeFineGrainedScheduler( - **kwargs) - self.matmul_int4_dequantize_weight_propagation_scheduler = MatmulINT4DequantizeWeightPropagationScheduler( + self.matmul_int4_dequantize_weight_propagation_scheduler = MatmulINT4DequantizeMMAWeightPropagationScheduler( **kwargs) super().__init__(**kwargs) diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_finegrained.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_mma.py similarity index 98% rename from bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_finegrained.py rename to bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_mma.py index a9fa86125..f651d8c0b 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_finegrained.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_mma.py @@ -17,7 +17,7 @@ from bitblas.base.roller.rasterization import NoRasterization from bitblas.base.utils import get_roller_hints_from_func from dataclasses import dataclass -from bitblas.ops.general_matmul.tilelang.dequantize.matmul_dequantize_tensorcore import ( +from bitblas.ops.general_matmul.tilelang.dequantize.matmul_dequantize_tile import ( MatmulDequantizeBaseScheduler, # noqa: F401 ) from bitblas.tl.base_hint import BaseTLHint @@ -27,7 +27,7 @@ @dataclass -class MatmulDequantizeFineGrainedScheduler(MatmulDequantizeBaseScheduler): +class MatmulDequantizeMMAScheduler(MatmulDequantizeBaseScheduler): # Tensor Core Warp Configuration block_row_warps: int = 2 @@ -43,7 +43,7 @@ class MatmulDequantizeFineGrainedScheduler(MatmulDequantizeBaseScheduler): class TLHint(BaseTLHint): - hint_type: str = "MatmulDequantizeFineGrainedScheduler" + hint_type: str = "MatmulDequantizeMMAScheduler" def __init__(self): super().__init__() @@ -508,10 +508,10 @@ def general_dequant_matmul( @dataclass -class MatmulINT4DequantizeFineGrainedScheduler(MatmulDequantizeFineGrainedScheduler): +class MatmulINT4DequantizeMMAScheduler(MatmulDequantizeMMAScheduler): - class TLHint(MatmulDequantizeFineGrainedScheduler.TLHint): - hint_type: str = "MatmulINT4DequantizeFineGrainedScheduler" + class TLHint(MatmulDequantizeMMAScheduler.TLHint): + hint_type: str = "MatmulINT4DequantizeMMAScheduler" def get_roller_configs(self, arch: TileDevice = None, topk: int = 10): layout = f"{'t' if self.trans_A else 'n'}{'t' if self.trans_B else 'n'}" diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_weight_transform.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_mma_weight_transform.py similarity index 98% rename from bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_weight_transform.py rename to bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_mma_weight_transform.py index 56d052eda..6e24ab2da 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_weight_transform.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_mma_weight_transform.py @@ -11,7 +11,7 @@ ) from bitblas.base.arch import TileDevice from bitblas.base.roller.hint import Hint -from .matmul_dequantize_tensorcore_finegrained import MatmulDequantizeFineGrainedScheduler +from .matmul_dequantize_mma import MatmulDequantizeMMAScheduler from bitblas.tl.mma_macro_generator import ( TensorCoreIntrinEmitterWithLadderTransform, INT4TensorCoreIntrinEmitterWithLadderTransform, @@ -33,13 +33,13 @@ @dataclass -class MatmulDequantizeWeightPropagationScheduler(MatmulDequantizeFineGrainedScheduler): +class MatmulDequantizeMMAWeightPropagationScheduler(MatmulDequantizeMMAScheduler): # force set default weight transform kind to LDMatrixTransform weight_transform_kind: TransformKind = TransformKind.LDMatrixTransform - class TLHint(MatmulDequantizeFineGrainedScheduler.TLHint): - hint_type: str = "MatmulDequantizeWeightPropagationScheduler" + class TLHint(MatmulDequantizeMMAScheduler.TLHint): + hint_type: str = "MatmulDequantizeMMAWeightPropagationScheduler" def apply_config( self, @@ -680,10 +680,10 @@ def is_b_smooth(self): @dataclass -class MatmulINT4DequantizeWeightPropagationScheduler(MatmulDequantizeWeightPropagationScheduler): +class MatmulINT4DequantizeMMAWeightPropagationScheduler(MatmulDequantizeMMAWeightPropagationScheduler): - class TLHint(MatmulDequantizeWeightPropagationScheduler.TLHint): - hint_type: str = "MatmulINT4DequantizeWeightPropagationScheduler" + class TLHint(MatmulDequantizeMMAWeightPropagationScheduler.TLHint): + hint_type: str = "MatmulINT4DequantizeMMAWeightPropagationScheduler" def get_roller_configs(self, arch: TileDevice = None, topk: int = 10): layout = f"{'t' if self.trans_A else 'n'}{'t' if self.trans_B else 'n'}" diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tile.py similarity index 99% rename from bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore.py rename to bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tile.py index 847eb49fd..c605ad9d9 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tile.py @@ -1,5 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +# Tile represents Tile Library + from bitblas import tvm as tvm from tvm import DataType import tvm.tl.language as T @@ -452,7 +454,7 @@ def num_elems_per_byte(self): @dataclass -class MatmulDequantizeBlockScheduler(MatmulDequantizeBaseScheduler): +class MatmulDequantizeTileLibraryScheduler(MatmulDequantizeBaseScheduler): # Default Tile Related Params block_M: int = 128 @@ -463,7 +465,7 @@ class MatmulDequantizeBlockScheduler(MatmulDequantizeBaseScheduler): enable_rasterization: bool = False # Enhance L2 Locality class TLHint(BaseTLHint): - hint_type: str = "MatmulDequantizeBlockScheduler" + hint_type: str = "MatmulDequantizeTileLibraryScheduler" def __init__(self): super().__init__() diff --git a/testing/python/operators/test_general_matmul_fp8.py b/testing/python/operators/test_general_matmul_fp8.py index b21cdc8ca..0ef2e64af 100644 --- a/testing/python/operators/test_general_matmul_fp8.py +++ b/testing/python/operators/test_general_matmul_fp8.py @@ -27,7 +27,7 @@ def matmul_torch_forward(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layo with_zeros=with_zeros, zeros_mode=zeros_mode, ) - matmul = Matmul(config=matmul_config, enable_tuning=True) + matmul = Matmul(config=matmul_config, enable_tuning=True, backend="tir") input_shape = (M, K) weight_shape = (N, K) if layout == "nt" else (K, N) @@ -93,7 +93,7 @@ def matmul_torch_forward_weight_dequantize(M, N, K, A_dtype, W_dtype, accum_dtyp propagate_a=False, propagate_b=False, ) - matmul = Matmul(config=matmul_config, enable_tuning=False) + matmul = Matmul(config=matmul_config, enable_tuning=False, backend="tir") input_shape = (M, K) weight_shape = (N, K) if layout == "nt" else (K, N) diff --git a/testing/python/operators/test_general_matmul_tilelang_impl.py b/testing/python/operators/test_general_matmul_tilelang_impl.py index e412e2298..b05d10390 100644 --- a/testing/python/operators/test_general_matmul_tilelang_impl.py +++ b/testing/python/operators/test_general_matmul_tilelang_impl.py @@ -4,7 +4,7 @@ from bitblas import tvm as tvm import bitblas.testing from tvm import tl -from bitblas.ops.general_matmul.tilelang.dense.matmul_tensorcore import ( +from bitblas.ops.general_matmul.tilelang.dense.matmul_tile import ( matmul_blocked, matmul_macro_tensorcore, matmul_macro_tensorcore_weight_propagation_level_ldmatrix, diff --git a/testing/python/operators/test_general_matmul_tilelang_kernel.py b/testing/python/operators/test_general_matmul_tilelang_kernel.py index e89701af8..b789bd4e1 100644 --- a/testing/python/operators/test_general_matmul_tilelang_kernel.py +++ b/testing/python/operators/test_general_matmul_tilelang_kernel.py @@ -4,23 +4,22 @@ from bitblas import tvm as tvm import bitblas.testing from tvm import tl -from bitblas.ops.general_matmul.tilelang.dense.matmul_tensorcore import ( - MatmulBlockScheduler, - MatmulFineGrainScheduler, - MatmulWeightPropagationScheduler, -) +from bitblas.ops.general_matmul.tilelang.dense.matmul_tile import ( + MatmulTileLibraryScheduler,) from bitblas.ops.general_matmul.tilelang.dequantize import ( MatmulDequantizeScheduler, - MatmulDequantizeFineGrainedScheduler, - MatmulDequantizeWeightPropagationScheduler, - MatmulINT4DequantizeFineGrainedScheduler, - MatmulINT4DequantizeWeightPropagationScheduler, + MatmulDequantizeMMAScheduler, + MatmulDequantizeMMAWeightPropagationScheduler, + MatmulINT4DequantizeMMAScheduler, + MatmulINT4DequantizeMMAWeightPropagationScheduler, ) -from bitblas.ops.general_matmul.tilelang.dense.matmul_tensorcore import ( - MatmulINT4FineGrainScheduler, - MatmulINT4WeightPropagationScheduler, +from bitblas.ops.general_matmul.tilelang.dense.matmul_mma import ( + MatmulMMAScheduler, + MatmulMMAWeightPropagationScheduler, + MatmulINT4MMAScheduler, + MatmulINT4MMAWeightPropagationScheduler, ) import torch @@ -41,7 +40,7 @@ def assert_matmul_blocked_with_default_correctness( out_dtype="float16", accum_dtype="float16", ): - matmul = MatmulBlockScheduler( + matmul = MatmulTileLibraryScheduler( M=M, N=N, K=K, @@ -92,7 +91,7 @@ def assert_matmul_blocked_apply_config_correctness( threads=128, enable_rasterization: bool = False, ): - matmul = MatmulBlockScheduler( + matmul = MatmulTileLibraryScheduler( M=M, N=N, K=K, @@ -145,7 +144,7 @@ def assert_matmul_fine_grained_with_default_correctness( accum_dtype="float16", ): - matmul = MatmulFineGrainScheduler( + matmul = MatmulMMAScheduler( M=M, N=N, K=K, @@ -199,7 +198,7 @@ def assert_matmul_fine_grained_apply_config_correctness( enable_rasterization: bool = False, ): - matmul = MatmulFineGrainScheduler( + matmul = MatmulMMAScheduler( M=M, N=N, K=K, @@ -253,7 +252,7 @@ def assert_matmul_weight_propagation_with_default_correctness( accum_dtype="float16", ): - matmul = MatmulWeightPropagationScheduler( + matmul = MatmulMMAWeightPropagationScheduler( M=M, N=N, K=K, @@ -319,7 +318,7 @@ def assert_matmul_weight_propagation_apply_config_correctness( enable_rasterization: bool = False, ): - matmul = MatmulWeightPropagationScheduler( + matmul = MatmulMMAWeightPropagationScheduler( M=M, N=N, K=K, @@ -384,7 +383,7 @@ def assert_matmul_int4_fine_grained_with_default_correctness( accum_dtype="int32", ): - matmul = MatmulINT4FineGrainScheduler( + matmul = MatmulINT4MMAScheduler( M=M, N=N, K=K, @@ -441,7 +440,7 @@ def assert_matmul_int4_fine_grained_apply_config_correctness( enable_rasterization: bool = False, ): - matmul = MatmulINT4FineGrainScheduler( + matmul = MatmulINT4MMAScheduler( M=M, N=N, K=K, @@ -499,7 +498,7 @@ def assert_matmul_int4_weight_propagation_with_default_correctness( accum_dtype="int32", ): - matmul = MatmulINT4WeightPropagationScheduler( + matmul = MatmulINT4MMAWeightPropagationScheduler( M=M, N=N, K=K, @@ -569,7 +568,7 @@ def assert_matmul_int4_weight_propagation_apply_config__correctness( enable_rasterization: bool = False, ): - matmul = MatmulINT4WeightPropagationScheduler( + matmul = MatmulINT4MMAWeightPropagationScheduler( M=M, N=N, K=K, @@ -648,7 +647,7 @@ def assert_matmul_fine_grained_dequant_int4_with_default_correctness( fast_decoding=False, zeros_mode="original", ): - matmul = MatmulINT4DequantizeFineGrainedScheduler( + matmul = MatmulINT4DequantizeMMAScheduler( M=M, N=N, K=K, @@ -739,7 +738,7 @@ def assert_matmul_fine_grained_dequant_int4_apply_config_correctness( num_stages=2, enable_rasterization: bool = False, ): - matmul = MatmulINT4DequantizeFineGrainedScheduler( + matmul = MatmulINT4DequantizeMMAScheduler( M=M, N=N, K=K, @@ -831,7 +830,7 @@ def assert_matmul_weight_transform_dequant_int4_with_default_correctness( fast_decoding=False, zeros_mode="original", ): - matmul = MatmulINT4DequantizeWeightPropagationScheduler( + matmul = MatmulINT4DequantizeMMAWeightPropagationScheduler( M=M, N=N, K=K, @@ -943,7 +942,7 @@ def assert_matmul_weight_transform_dequant_int4_apply_config_correctness( num_stages=2, enable_rasterization: bool = False, ): - matmul = MatmulINT4DequantizeWeightPropagationScheduler( + matmul = MatmulINT4DequantizeMMAWeightPropagationScheduler( M=M, N=N, K=K, @@ -1195,7 +1194,7 @@ def assert_matmul_fine_grained_dequant_with_default_correctness( import numpy as np from bitblas.quantization import general_compress, interleave_weight - matmul = MatmulDequantizeFineGrainedScheduler( + matmul = MatmulDequantizeMMAScheduler( M=M, N=N, K=K, @@ -1325,7 +1324,7 @@ def assert_matmul_weight_transform_dequant_with_default_correctness( import numpy as np from bitblas.quantization import general_compress, interleave_weight - matmul = MatmulDequantizeWeightPropagationScheduler( + matmul = MatmulDequantizeMMAWeightPropagationScheduler( M=M, N=N, K=K, diff --git a/testing/python/operators/test_general_matmul_tilelang_scheduler.py b/testing/python/operators/test_general_matmul_tilelang_scheduler.py index f5b85d409..83428b7a1 100644 --- a/testing/python/operators/test_general_matmul_tilelang_scheduler.py +++ b/testing/python/operators/test_general_matmul_tilelang_scheduler.py @@ -4,8 +4,8 @@ from bitblas import tvm as tvm import bitblas.testing from tvm.ir import structural_equal -from bitblas.ops.general_matmul.tilelang.dense.matmul_tensorcore import ( - MatmulBlockScheduler,) +from bitblas.ops.general_matmul.tilelang.dense.matmul_tile import ( + MatmulTileLibraryScheduler,) from bitblas.ops.general_matmul.tilelang.dequantize import (MatmulDequantizeScheduler) from bitblas.ops.general_matmul.tilelang.dense.gemv_simt import GemvFineGrainSIMTScheduler from bitblas.ops.general_matmul.tilelang.dense import MatmulScheduler @@ -49,7 +49,7 @@ def assert_dense_scheduler_simplify(M, in_dtype="float16", out_dtype="float16", accum_dtype="float16"): - matmul = MatmulBlockScheduler( + matmul = MatmulTileLibraryScheduler( M=M, N=N, K=K, @@ -60,7 +60,7 @@ def assert_dense_scheduler_simplify(M, accum_dtype=accum_dtype, ).deactivate_simplify().with_default_config() - simplified = MatmulBlockScheduler.Simplify(matmul) + simplified = MatmulTileLibraryScheduler.Simplify(matmul) is_equal = structural_equal(matmul, simplified) if is_equal: