Merged · 31 commits
a1e7c81  Refactor: Rename tensor core related classes and imports to use tile … (LeiWang1999, Feb 2, 2025)
566fc67  lint fix (LeiWang1999, Feb 2, 2025)
563a704  Fix: Update Matmul initialization to specify backend and clean up imp… (LeiWang1999, Feb 2, 2025)
d95fb42  lint fix (LeiWang1999, Feb 2, 2025)
97ecf57  Update subproject commit for TVM dependency (LeiWang1999, Feb 2, 2025)
fe44dd7  Refactor: Update imports to use tilelang instead of tvm.tl.language (LeiWang1999, Feb 3, 2025)
7dc068c  Refactor: Clean up import statements and formatting in bitblas module (LeiWang1999, Feb 3, 2025)
691a0dc  Fix: Add newline injection to .bashrc if the last line is not empty i… (LeiWang1999, Feb 3, 2025)
fd4c1a6  Update submodule URLs and branches for TVM and TileLang (LeiWang1999, Feb 3, 2025)
e7ad6a9  Update tilelang submodule URL and add new subproject commit (LeiWang1999, Feb 3, 2025)
6da2f5f  Update cutlass submodule URL to point to tile-ai repository (LeiWang1999, Feb 3, 2025)
30f1a95  Merge branch 'main' of https://github.com/microsoft/BitBLAS into sepa… (LeiWang1999, Feb 3, 2025)
8f68896  Refactor: Split class definition for MatmulINT4DequantizeMMAWeightPro… (LeiWang1999, Feb 3, 2025)
af0a134  Enhance environment variable handling for TVM and TileLang paths in i… (LeiWang1999, Feb 3, 2025)
b9bb657  Remove unnecessary blank line in initialization of TILELANG_IMPORT_PA… (LeiWang1999, Feb 3, 2025)
9f7d4c6  Add build_tilelang function to setup.py for TILELANG integration (LeiWang1999, Feb 4, 2025)
1cc6886  Add TILELANG build step in setup.py (LeiWang1999, Feb 4, 2025)
8a53cd4  Update TileLang subproject and improve type mappings in TLCUDASourceW… (LeiWang1999, Feb 4, 2025)
3cac227  Refactor line continuation for better readability in gemv_simt.py (LeiWang1999, Feb 4, 2025)
7ce4860  Merge branch 'main' of https://github.com/microsoft/BitBLAS into sepa… (LeiWang1999, Feb 4, 2025)
494434e  Update TileLang subproject to latest commit (LeiWang1999, Feb 4, 2025)
7bf057b  Merge branch 'main' of https://github.com/microsoft/BitBLAS into sepa… (LeiWang1999, Feb 4, 2025)
6d8e05d  Update subproject commit for TVM to the latest version (LeiWang1999, Feb 4, 2025)
34d77e4  Update TileLang subproject and fix CUDA architecture checks (LeiWang1999, Feb 5, 2025)
895080e  Disable tuning for Matmul in FP8 tests and update backend configuration (LeiWang1999, Feb 5, 2025)
ab990ef  Update TileLang subproject to latest commit (LeiWang1999, Feb 6, 2025)
9d4bcc9  Refactor code for improved readability by adjusting line breaks and f… (LeiWang1999, Feb 6, 2025)
b6edc52  Update float16 type mapping to use 'half_t' in TLCUDASourceWrapper (LeiWang1999, Feb 7, 2025)
167c82b  Update TileLang subproject and improve error logging in tuner.py (LeiWang1999, Feb 7, 2025)
9b86b6d  Remove unnecessary whitespace in tuner.py for cleaner code (LeiWang1999, Feb 8, 2025)
4fd0670  Update bfloat16 type mapping to use 'bfloat16_t' in TLCUDASourceWrapper (LeiWang1999, Feb 8, 2025)
2 changes: 1 addition & 1 deletion 3rdparty/tvm
2 changes: 1 addition & 1 deletion bitblas/base/arch/cuda.py
@@ -27,7 +27,7 @@ def is_volta_arch(arch: TileDevice) -> bool:
 def is_ampere_arch(arch: TileDevice) -> bool:
     conditions = [True]
     conditions.append(is_cuda_arch(arch))
-    conditions.append(arch.sm_version >= 80 and arch.sm_version < 90)
+    conditions.append(arch.sm_version >= 80 and arch.sm_version < 89)
     return all(conditions)


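To make the effect of the tightened bound concrete, here is a hedged sketch of how the sm_version ranges partition across the arch predicates after this change. `is_ada_arch` and `is_hopper_arch` are taken from the matmul.py diff below; the cutoffs shown follow NVIDIA's published compute capabilities, and `TileDevice` is stubbed rather than imported.

```python
# Hedged sketch, not the BitBLAS source: how the arch predicates partition
# sm_version once Ampere stops at 88. TileDevice is stubbed for brevity.
class TileDevice:

    def __init__(self, sm_version: int):
        self.sm_version = sm_version


def is_ampere_arch(arch: TileDevice) -> bool:
    # sm_80..sm_88: Ampere (A100 is sm_80, RTX 30-series is sm_86)
    return 80 <= arch.sm_version < 89


def is_ada_arch(arch: TileDevice) -> bool:
    # sm_89: Ada Lovelace (RTX 40-series, L40); previously swallowed by < 90
    return arch.sm_version == 89


def is_hopper_arch(arch: TileDevice) -> bool:
    # sm_90: Hopper (H100)
    return arch.sm_version >= 90


# With the old `< 90` bound, sm_89 was misclassified as Ampere.
assert is_ada_arch(TileDevice(89)) and not is_ampere_arch(TileDevice(89))
```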
4 changes: 2 additions & 2 deletions bitblas/builder/wrapper/tl.py
@@ -18,8 +18,8 @@
 class TLCUDASourceWrapper(object):
     _TYPE_MAP = {
         "float32": "float",
-        "float16": "half",
-        "bfloat16": "__nv_bfloat16",
+        "float16": "half_t",
+        "bfloat16": "bfloat16_t",
         "e4m3_float8": "__nv_fp8_e4m3",
         "e5m2_float8": "__nv_fp8_e5m2",
         "float64": "double",
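For context, a minimal sketch of how a wrapper consumes this mapping when emitting a CUDA kernel signature; `half_t` and `bfloat16_t` are TileLang/CUTLASS-style aliases rather than the raw `half`/`__nv_bfloat16` CUDA types. The `emit_param` helper is hypothetical, not part of the BitBLAS API.

```python
# Minimal sketch: consuming the dtype -> C++ type map when generating wrapper
# source. `emit_param` is a hypothetical helper, not BitBLAS API.
_TYPE_MAP = {
    "float32": "float",
    "float16": "half_t",        # TileLang/CUTLASS alias instead of raw `half`
    "bfloat16": "bfloat16_t",   # alias instead of `__nv_bfloat16`
    "e4m3_float8": "__nv_fp8_e4m3",
    "e5m2_float8": "__nv_fp8_e5m2",
    "float64": "double",
}


def emit_param(name: str, dtype: str) -> str:
    if dtype not in _TYPE_MAP:
        raise ValueError(f"no CUDA type mapping for dtype {dtype!r}")
    return f"{_TYPE_MAP[dtype]}* __restrict__ {name}"


print(emit_param("B", "bfloat16"))  # bfloat16_t* __restrict__ B
```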
4 changes: 3 additions & 1 deletion bitblas/gpu/matmul_analysis.py
@@ -649,7 +649,9 @@ def check_last_trait(region: List[Range]):
     if target.kind.name == "cuda" and check_sm_version(target.arch) >= 70:
         in_dtype, out_dtype = get_in_out_dtypes(block_stmt)
         if not is_tensorcore_supported_precision(in_dtype, out_dtype, arch=get_arch(target)):
-            logger.debug("The input and output dtype is not supported by tensorcore")
+            logger.debug(
+                f"The input and output dtype ({in_dtype}, {out_dtype}) is not supported by tensorcore"
+            )
             return func, None
 
     # reindex and transform functions
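For reference, a toy stand-in for the gate this message reports on, using an illustrative subset of the (in_dtype, out_dtype) pairs Ampere tensor cores accept. The authoritative table lives in the arch module, and the real function also takes an `arch` argument; both are simplified away here.

```python
# Toy stand-in for is_tensorcore_supported_precision: an illustrative subset
# of dtype pairs, not the full BitBLAS table (which is also arch-dependent).
SUPPORTED_PAIRS = {
    ("float16", "float16"),
    ("float16", "float32"),
    ("bfloat16", "float32"),
    ("int8", "int32"),
}


def is_tensorcore_supported_precision(in_dtype: str, out_dtype: str) -> bool:
    return (in_dtype, out_dtype) in SUPPORTED_PAIRS


# The improved message names the offending pair instead of a generic complaint:
in_dtype, out_dtype = "float64", "float64"
if not is_tensorcore_supported_precision(in_dtype, out_dtype):
    print(f"The input and output dtype ({in_dtype}, {out_dtype}) "
          "is not supported by tensorcore")
```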
8 changes: 4 additions & 4 deletions bitblas/ops/general_matmul/tilelang/dense/__init__.py
@@ -102,7 +102,7 @@ def ampere_select_scheduler(
 
     trans_A, trans_B = parse_layout(layout)
 
-    def can_apply_fine_grain_scheduler(trans_A, trans_B, propagate_a, propagate_b):
+    def can_apply_mma_scheduler(trans_A, trans_B, propagate_a, propagate_b):
         conditions = []
         conditions.append(trans_A is False)
         conditions.append(trans_B is True)
@@ -116,7 +116,7 @@ def can_apply_block_scheduler(propagate_a, propagate_b):
         conditions.append(propagate_b == TransformKind.NonTransform)
         return all(conditions)
 
-    def can_apply_weight_propagation_scheduler(trans_A, trans_B, propagate_a, propagate_b):
+    def can_apply_mma_weight_propagation_scheduler(trans_A, trans_B, propagate_a, propagate_b):
         conditions = []
         conditions.append(trans_A is False)
         conditions.append(trans_B is True)
@@ -127,7 +127,7 @@ def can_apply_weight_propagation_scheduler(trans_A, trans_B, propagate_a, propag
     def is_int4_dtype(dtype):
        return dtype == "int4" or dtype == "uint4"
 
-    if can_apply_weight_propagation_scheduler(trans_A, trans_B, propagate_a, propagate_b):
+    if can_apply_mma_weight_propagation_scheduler(trans_A, trans_B, propagate_a, propagate_b):
         Scheduler = MatmulMMAWeightPropagationScheduler if not is_int4_dtype(
             in_dtype) else MatmulINT4MMAWeightPropagationScheduler
         return Scheduler(
@@ -141,7 +141,7 @@ def is_int4_dtype(dtype):
             accum_dtype=accum_dtype,
             with_bias=with_bias,
         )
-    if can_apply_fine_grain_scheduler(trans_A, trans_B, propagate_a, propagate_b):
+    if can_apply_mma_scheduler(trans_A, trans_B, propagate_a, propagate_b):
         Scheduler = MatmulMMAScheduler if not is_int4_dtype(in_dtype) else MatmulINT4MMAScheduler
         return Scheduler(
             M=M,
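Taken together, the renamed predicates choose between the plain MMA scheduler and the weight-propagation variant from the layout and transform flags. Below is a condensed, self-contained sketch under stated assumptions: the `TransformKind` stub and the exact propagate conditions for the weight-propagation path are assumptions, since part of the real condition list is folded out of the diff above.

```python
# Condensed sketch of the renamed selection predicates; simplified and partly
# assumed, since some condition lines are folded out of the diff above.
from enum import Enum


class TransformKind(Enum):  # stand-in for bitblas's TransformKind
    NonTransform = 0
    LDMatrixTransform = 2


def can_apply_mma_scheduler(trans_A, trans_B, propagate_a, propagate_b):
    # Plain MMA path: A not transposed, B transposed, no layout propagation.
    return (trans_A is False and trans_B is True and
            propagate_a == TransformKind.NonTransform and
            propagate_b == TransformKind.NonTransform)


def can_apply_mma_weight_propagation_scheduler(trans_A, trans_B, propagate_a,
                                               propagate_b):
    # Weight-propagation path: same layout, but B carries a transform
    # (assumed condition; the exact check is hidden behind the fold).
    return (trans_A is False and trans_B is True and
            propagate_a == TransformKind.NonTransform and
            propagate_b != TransformKind.NonTransform)


# NT layout with a transformed weight picks the weight-propagation scheduler.
assert can_apply_mma_weight_propagation_scheduler(
    False, True, TransformKind.NonTransform, TransformKind.LDMatrixTransform)
```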
64 changes: 35 additions & 29 deletions bitblas/ops/general_matmul/tilelang/dense/matmul.py
@@ -9,8 +9,11 @@
     TileDevice,
     is_ampere_arch,
     is_volta_arch,
+    is_ada_arch,
+    is_hopper_arch,
     is_tensorcore_supported_precision,
 )
+from tilelang.intrinsics.utils import get_mma_micro_size
 from dataclasses import dataclass
 from bitblas.tl.base_hint import BaseTLHint
 
@@ -39,20 +42,20 @@ class MatmulScheduler(MatmulBaseParams):
     gemv_scheduler: Optional[GemvFineGrainSIMTScheduler] = None
     matmul_simt_scheduler: Optional[MatmulFineGrainSIMTScheduler] = None
     matmul_block_scheduler: Optional[MatmulTileLibraryScheduler] = None
-    matmul_fine_grain_scheduler: Optional[MatmulMMAScheduler] = None
-    matmul_weight_propagation_scheduler: Optional[MatmulMMAWeightPropagationScheduler] = None
-    matmul_int4_fine_grain_scheduler: Optional[MatmulINT4MMAScheduler] = None
-    matmul_int4_weight_propagation_scheduler: Optional[
+    matmul_mma_scheduler: Optional[MatmulMMAScheduler] = None
+    matmul_mma_weight_propagation_scheduler: Optional[MatmulMMAWeightPropagationScheduler] = None
+    matmul_int4_mma_scheduler: Optional[MatmulINT4MMAScheduler] = None
+    matmul_int4_mma_weight_propagation_scheduler: Optional[
         MatmulINT4MMAWeightPropagationScheduler] = None
 
     def __init__(self, **kwargs):
         self.gemv_scheduler = GemvFineGrainSIMTScheduler(**kwargs)
         self.matmul_simt_scheduler = MatmulFineGrainSIMTScheduler(**kwargs)
         self.matmul_block_scheduler = MatmulTileLibraryScheduler(**kwargs)
-        self.matmul_fine_grain_scheduler = MatmulMMAScheduler(**kwargs)
-        self.matmul_weight_propagation_scheduler = MatmulMMAWeightPropagationScheduler(**kwargs)
-        self.matmul_int4_fine_grain_scheduler = MatmulINT4MMAScheduler(**kwargs)
-        self.matmul_int4_weight_propagation_scheduler = MatmulINT4MMAWeightPropagationScheduler(
+        self.matmul_mma_scheduler = MatmulMMAScheduler(**kwargs)
+        self.matmul_mma_weight_propagation_scheduler = MatmulMMAWeightPropagationScheduler(**kwargs)
+        self.matmul_int4_mma_scheduler = MatmulINT4MMAScheduler(**kwargs)
+        self.matmul_int4_mma_weight_propagation_scheduler = MatmulINT4MMAWeightPropagationScheduler(
            **kwargs)
        super().__init__(**kwargs)

@@ -72,14 +75,13 @@ def dispatch_ampere_scheduler(self, arch: TileDevice) -> BaseScheduler:
             if is_tensorcore_supported_precision(in_dtype, accum_dtype, arch):
                 if weight_transform_kind != TransformKind.NonTransform:
                     # INT4 Can be fused into general dequantize
-                    return self.matmul_int4_weight_propagation_scheduler if in_dtype == "int4" else self.matmul_weight_propagation_scheduler
-                return self.matmul_int4_fine_grain_scheduler if in_dtype == "int4" else self.matmul_fine_grain_scheduler
+                    return self.matmul_int4_mma_weight_propagation_scheduler if in_dtype == "int4" else self.matmul_mma_weight_propagation_scheduler
+                return self.matmul_int4_mma_scheduler if in_dtype == "int4" else self.matmul_mma_scheduler
             else:
                 return self.matmul_simt_scheduler
         else:
-            minimal_tensorcore_threshold: List[int, int,
-                                               int] = [8, 16, 32
-                                                      ] if accum_dtype == "int32" else [8, 16, 16]
+            _, _, micro_size_k = get_mma_micro_size(in_dtype)
+            minimal_tensorcore_threshold: List[int, int, int] = [8, 16, micro_size_k]
             if minimal_tensorcore_threshold[0] > M or minimal_tensorcore_threshold[
                     1] > N or minimal_tensorcore_threshold[2] > K:
                 if in_dtype == "int4":
@@ -90,10 +92,11 @@ def dispatch_ampere_scheduler(self, arch: TileDevice) -> BaseScheduler:
                 return self.gemv_scheduler
             elif is_tensorcore_supported_precision(in_dtype, accum_dtype, arch):
                 if self.weight_transform_kind != TransformKind.NonTransform:
-                    return (self.matmul_int4_weight_propagation_scheduler
-                            if in_dtype == "int4" else self.matmul_weight_propagation_scheduler)
+                    return (self.matmul_int4_mma_weight_propagation_scheduler
+                            if in_dtype == "int4" else self.matmul_mma_weight_propagation_scheduler)
                 else:
-                    return self.matmul_int4_fine_grain_scheduler if in_dtype == "int4" else self.matmul_block_scheduler
+                    # by default, use the mma_scheduler
+                    return self.matmul_int4_mma_scheduler if in_dtype == "int4" else self.matmul_mma_scheduler
             else:
                 return self.matmul_simt_scheduler

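The threshold change above swaps the hard-coded K minimum (`32` for int32 accumulation, else `16`) for the dtype's own MMA micro-shape. A hedged sketch of the idea, with illustrative micro sizes standing in for whatever `tilelang.intrinsics.utils.get_mma_micro_size` actually returns:

```python
# Hedged sketch: dtype-driven micro-size lookup replacing hard-coded
# tensor-core thresholds. Micro shapes are illustrative (m16n8k16 for half
# precision, m16n8k32 for int8), not read from tilelang.
def get_mma_micro_size(in_dtype: str):
    if in_dtype == "int8":
        return 16, 8, 32
    return 16, 8, 16  # float16 / bfloat16


def meets_tensorcore_threshold(M: int, N: int, K: int, in_dtype: str) -> bool:
    _, _, micro_size_k = get_mma_micro_size(in_dtype)
    minimal_tensorcore_threshold = [8, 16, micro_size_k]
    return (M >= minimal_tensorcore_threshold[0] and
            N >= minimal_tensorcore_threshold[1] and
            K >= minimal_tensorcore_threshold[2])


assert meets_tensorcore_threshold(8, 16, 32, "int8")
assert not meets_tensorcore_threshold(8, 16, 16, "int8")  # K under int8 micro-k
```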
@@ -131,7 +134,10 @@ def dispatch_volta_scheduler(self, arch: TileDevice) -> BaseScheduler:
         return self.matmul_simt_scheduler
 
     def dispatch_scheduler(self, arch: TileDevice) -> BaseScheduler:
-        if is_ampere_arch(arch):
+        if is_hopper_arch(arch):
+            logger.warning("Hopper architecture is not fully supported yet, fallback to Ada")
+            return self.dispatch_ampere_scheduler(arch)
+        elif is_ampere_arch(arch) or is_ada_arch(arch):
             return self.dispatch_ampere_scheduler(arch)
         elif is_volta_arch(arch):
             return self.dispatch_volta_scheduler(arch)
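The new dispatch order reads: Hopper warns and falls back to the Ampere path, Ada shares the Ampere path outright, and Volta keeps its own dispatcher. A minimal sketch with the arch predicates reduced to sm_version checks (the returned strings are placeholders for the scheduler objects):

```python
# Minimal sketch of the new dispatch order; arch predicates reduced to
# sm_version checks and schedulers reduced to placeholder strings.
import logging

logger = logging.getLogger(__name__)


def dispatch_scheduler(sm_version: int) -> str:
    if sm_version >= 90:  # Hopper: no dedicated path yet
        logger.warning(
            "Hopper architecture is not fully supported yet, fallback to Ada")
        return "ampere_scheduler"
    elif 80 <= sm_version <= 89:  # Ampere (80-88) and Ada (89) share one path
        return "ampere_scheduler"
    elif 70 <= sm_version < 80:  # Volta
        return "volta_scheduler"
    raise ValueError(f"Unsupported architecture: sm_{sm_version}")


assert dispatch_scheduler(90) == "ampere_scheduler"  # warns, then falls back
```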
@@ -143,10 +149,10 @@ def detect_scheduler_from_hint(self, hint: BaseTLHint) -> BaseScheduler:
             self.gemv_scheduler,
             self.matmul_simt_scheduler,
             self.matmul_block_scheduler,
-            self.matmul_fine_grain_scheduler,
-            self.matmul_weight_propagation_scheduler,
-            self.matmul_int4_fine_grain_scheduler,
-            self.matmul_int4_weight_propagation_scheduler,
+            self.matmul_mma_scheduler,
+            self.matmul_mma_weight_propagation_scheduler,
+            self.matmul_int4_mma_scheduler,
+            self.matmul_int4_mma_weight_propagation_scheduler,
         ]:
             try:
                 scheduler_hint_type = scheduler.get_hint_type()
@@ -213,10 +219,10 @@ def set_dynamic_range(self, dynamic_range: Dict[str, int]) -> "BaseScheduler":
             self.gemv_scheduler,
             self.matmul_simt_scheduler,
             self.matmul_block_scheduler,
-            self.matmul_fine_grain_scheduler,
-            self.matmul_weight_propagation_scheduler,
-            self.matmul_int4_fine_grain_scheduler,
-            self.matmul_int4_weight_propagation_scheduler,
+            self.matmul_mma_scheduler,
+            self.matmul_mma_weight_propagation_scheduler,
+            self.matmul_int4_mma_scheduler,
+            self.matmul_int4_mma_weight_propagation_scheduler,
         ]:
             scheduler.set_dynamic_range(dynamic_range)
         return self
@@ -227,10 +233,10 @@ def with_arch(self, arch):
             self.gemv_scheduler,
             self.matmul_simt_scheduler,
             self.matmul_block_scheduler,
-            self.matmul_fine_grain_scheduler,
-            self.matmul_weight_propagation_scheduler,
-            self.matmul_int4_fine_grain_scheduler,
-            self.matmul_int4_weight_propagation_scheduler,
+            self.matmul_mma_scheduler,
+            self.matmul_mma_weight_propagation_scheduler,
+            self.matmul_int4_mma_scheduler,
+            self.matmul_int4_mma_weight_propagation_scheduler,
         ]:
             scheduler.with_arch(arch)
         return self
2 changes: 1 addition & 1 deletion bitblas/ops/general_matmul/tilelang/dense/matmul_mma.py
@@ -7,7 +7,7 @@
 from tvm import DataType
 import tilelang.language as T
 from typing import Optional, List
-from bitblas.tl.utils import (
+from tilelang.intrinsics.utils import (
     get_mma_micro_size,
     make_mma_swizzle_layout as make_swizzle_layout,
 )
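This file and matmul_tile.py below now pull the MMA helpers from TileLang itself rather than from bitblas.tl.utils, consistent with the TileLang split this PR carries out. A small usage sketch (requires a tilelang installation; only the third element of the tuple, the k micro size, is confirmed by the matmul.py diff above, so the first two positions are an assumption):

```python
# Usage sketch for the relocated imports; requires tilelang to be installed.
from tilelang.intrinsics.utils import (
    get_mma_micro_size,
    make_mma_swizzle_layout as make_swizzle_layout,
)

# Per the matmul.py diff, the third returned value is the k micro size;
# the (m, n) ordering of the first two is assumed here.
micro_m, micro_n, micro_k = get_mma_micro_size("float16")
print(micro_m, micro_n, micro_k)  # e.g. 16 8 16 for half precision
```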
2 changes: 1 addition & 1 deletion bitblas/ops/general_matmul/tilelang/dense/matmul_tile.py
@@ -6,7 +6,7 @@
 from bitblas import tilelang as tilelang
 import tilelang.language as T
 from typing import Optional, List
-from bitblas.tl.utils import (
+from tilelang.intrinsics.utils import (
     get_mma_micro_size,
     make_mma_swizzle_layout as make_swizzle_layout,
 )