Skip to content
This repository was archived by the owner on Feb 24, 2026. It is now read-only.

Commit 08b56e8

Browse files
authored
[Dev] Migrate default backend from tir into tilelang (#270)
* fix for relax * lint fix * save import bitblas time * bug fix for tl backend * support input transform_kind * hint identifier * annotate hint type for dequantize * enhance swizzling * Enhance for hardware aware tuning * test fix * remove pad factor * introduce legalize dynamic pass * update 3rdparty * testfix * test code commit * enhance typing and fix test for int4 dequantize gemm * lint fix * TEST FIX * lint fix * Bugfix for bias * lint fix * lint fix * test fix * Implement Bias
1 parent f250ec5 commit 08b56e8

22 files changed

+425
-98
lines changed

bitblas/base/roller/policy/tensorcore.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -328,9 +328,8 @@ def _score(node, thread): # small is better
328328
# TODO: This is a dummy mul which avoids reusing some shared memory.
329329
# Should be removed in the future.
330330
if td.smem_cost > (self.arch.smem_cap):
331-
debug_message = f"Tile Dict: {td.output_tile} Shared memory exceeds the static capacity," \
332-
" use dynamic shared memory."
333-
logger.debug(debug_message)
331+
# Tile Dict: {td.output_tile} Shared memory exceeds the static capacity
332+
# use dynamic shared memory.
334333
codegen_dict.shared_scope = "shared.dyn"
335334

336335
codegen_dict.shared_scope = "shared.dyn"

bitblas/builder/wrapper/tl.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ class TLCUDASourceWrapper(object):
1919
_TYPE_MAP = {
2020
"float32": "float",
2121
"float16": "half_t",
22-
"bfloat16": "__nv_bfloat16",
23-
"e4m3_float8": "__nv_fp8_e4m3",
24-
"e5m2_float8": "__nv_fp8_e5m2",
22+
"bfloat16": "bfloat16_t",
23+
"e4m3_float8": "float_e4m3_t",
24+
"e5m2_float8": "float_e5m2_t",
2525
"float64": "double",
2626
"int64": "int64_t",
2727
"int32": "int",

bitblas/module/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ def forward(self, A, output=None):
274274
self.init_params()
275275
args = [A_void, *self.q_params]
276276
if output is None:
277-
output = torch.empty(
277+
output = torch.zeros(
278278
A.shape[:-1] + (self.out_features,),
279279
dtype=getattr(torch, self.bitblas_matmul.out_dtype),
280280
device=A.device)

bitblas/ops/general_matmul/__init__.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .tilelang.dense import select_scheduler as consistent_scheduler
1616
from .tilelang.dequantize import select_scheduler as weight_dequantize_scheduler
1717
from ...base.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2
18+
from bitblas.utils import retrieve_func_from_module
1819
from bitblas.utils.target_detector import auto_detect_nvidia_target
1920
from dataclasses import dataclass
2021
from ..ladder_permutate import LadderPermutate, LadderPermutateConfig
@@ -350,7 +351,7 @@ def __init__(
350351
target: Optional[Union[str, Target]] = None,
351352
enable_tuning: bool = True,
352353
from_database: bool = False,
353-
backend: str = "tir",
354+
backend: str = "tl",
354355
):
355356
# if from database, we should disable default schedule
356357
# to save compilation time
@@ -383,13 +384,13 @@ def __init__(
383384
if target.kind.name not in ("cuda", "hip"):
384385
raise ValueError("Currently only support cuda and hip target")
385386

386-
self.dispatch_tir(target, from_database, source_format, enable_tuning)
387+
self.dispatch(target, from_database, source_format, enable_tuning)
387388

388-
def dispatch_tir(self,
389-
target: Target,
390-
from_database: bool = False,
391-
source_format: str = "uint",
392-
enable_tuning: bool = True):
389+
def dispatch(self,
390+
target: Target,
391+
from_database: bool = False,
392+
source_format: str = "uint",
393+
enable_tuning: bool = True):
393394

394395
if isinstance(self.M, Tuple):
395396
self.dynamic_range = {"m": self.M}
@@ -638,7 +639,21 @@ def post_process(self, code: str) -> str:
638639
return code
639640

640641
def retrieve_weight_shape(self):
641-
return [int(i) for i in self.prim_func.buffer_map[self.prim_func.params[1]].shape]
642+
prim_func = self.prim_func
643+
644+
# retrieve from tilelang backend
645+
if prim_func is None and self.scheduled_ir_module is not None:
646+
prim_func = retrieve_func_from_module(self.scheduled_ir_module)
647+
648+
if prim_func is None and self.is_tilelang_backend():
649+
# If from_database and from tilelang backend, we should construct a default module
650+
self._update_optimized_mod(self.scheduler_with_default(self.scheduler))
651+
prim_func = retrieve_func_from_module(self.scheduled_ir_module)
652+
653+
if prim_func is not None:
654+
return [int(i) for i in prim_func.buffer_map[prim_func.params[1]].shape]
655+
656+
raise ValueError("The weight shape is not available.")
642657

643658
def transform_weight(self, weight, scale=None, zeros=None, bias=None):
644659
"""

bitblas/ops/general_matmul/tilelang/dense/gemv_simt.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,8 @@ def apply_config(
111111

112112
A_shape = (M, K)
113113
B_shape = (N, K)
114-
C_shape = (M, N)
115114
Bias_shape = (N,)
115+
C_shape = (M, N)
116116

117117
dp4a_size = 4
118118
use_dp4a = in_dtype == "int8" and accum_dtype == "int32"
@@ -121,8 +121,8 @@ def apply_config(
121121
def main(
122122
A: T.Buffer(A_shape, in_dtype),
123123
B: T.Buffer(B_shape, in_dtype),
124-
C: T.Buffer(C_shape, out_dtype),
125124
Bias: T.Buffer(Bias_shape, out_dtype),
125+
C: T.Buffer(C_shape, out_dtype),
126126
):
127127
with T.Kernel(
128128
T.ceildiv(N, n_partition), M, threads=(reduce_thread, n_partition)) as (
@@ -186,5 +186,4 @@ def __post_init__(self):
186186
# Validate the matrix transpose settings
187187
assert self.trans_A is False, "Currently only support Matrix A not transposed"
188188
assert self.trans_B is True, "Currently only support Matrix B transposed"
189-
assert self.with_bias is False, "Currently only support without bias"
190189
return

bitblas/ops/general_matmul/tilelang/dense/matmul.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,6 @@ def __post_init__(self):
242242
# Validate the matrix transpose settings
243243
assert self.trans_A is False, "Currently only support Matrix A not transposed"
244244
assert self.trans_B is True, "Currently only support Matrix B transposed"
245-
assert self.with_bias is False, "Currently only support without bias"
246245
assert self.input_transform_kind == TransformKind.NonTransform, "Currently only support NonTransform for input"
247246

248247
return

bitblas/ops/general_matmul/tilelang/dense/matmul_simt.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def check_require_cache(self) -> bool:
5757

5858
conditions: List[bool] = []
5959
conditions.append(False)
60-
# Bias Add should be done in shared memory
60+
# Bias Add should be performed in shared memory
6161
conditions.append(with_bias)
6262
return any(conditions) # Always set to False Currently
6363

@@ -172,6 +172,8 @@ def apply_config(
172172
self.accum_dtype,
173173
)
174174

175+
with_bias = self.with_bias
176+
175177
shared_scope = "shared.dyn"
176178

177179
block_M = block_size_x * thread_row_tiles
@@ -183,6 +185,7 @@ def apply_config(
183185
C_shape = (M, N)
184186
A_shared_shape = (block_M, block_K)
185187
B_shared_shape = (block_N, block_K)
188+
Bias_shape = (N,)
186189

187190
threads = thread_row_tiles * thread_col_tiles
188191
local_size_a = block_M // thread_row_tiles
@@ -198,6 +201,7 @@ def apply_config(
198201
def main(
199202
A: T.Buffer(A_shape, in_dtype),
200203
B: T.Buffer(B_shape, in_dtype),
204+
Bias: T.Buffer(Bias_shape, out_dtype),
201205
C: T.Buffer(C_shape, out_dtype),
202206
):
203207
with T.Kernel(
@@ -249,21 +253,28 @@ def main(
249253
else:
250254
for dp4a_idx in T.serial(dp4a_size):
251255
C_local[i * local_size_b + j] += (
252-
A_local[i, mk * dp4a_size + dp4a_idx] *
253-
B_local[j, mk * dp4a_size + dp4a_idx])
254-
255-
for i, j in T.grid(local_size_a, local_size_b):
256-
C[
257-
by * block_M + warp_m * local_size_a + i,
258-
bx * block_N + warp_n * local_size_b + j,
259-
] = C_local[i * local_size_b + j]
256+
A_local[i,
257+
mk * dp4a_size + dp4a_idx].astype(accum_dtype) *
258+
B_local[j,
259+
mk * dp4a_size + dp4a_idx].astype(accum_dtype))
260+
261+
if with_bias:
262+
for i, j in T.grid(local_size_a, local_size_b):
263+
C_local[i * local_size_b + j] += Bias[bx * block_N + warp_n * local_size_b +
264+
j]
265+
266+
for i in T.serial(local_size_a):
267+
for j in T.vectorized(local_size_b):
268+
C[
269+
by * block_M + warp_m * local_size_a + i,
270+
bx * block_N + warp_n * local_size_b + j,
271+
] = C_local[i * local_size_b + j]
260272

261273
return self.post_process(main)
262274

263275
def __post_init__(self):
264276
# Validate the matrix transpose settings
265277
assert self.trans_A is False, "Currently only support Matrix A not transposed"
266278
assert self.trans_B is True, "Currently only support Matrix B transposed"
267-
assert self.with_bias is False, "Currently only support without bias"
268279

269280
return

bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def check_require_cache(self) -> bool:
6969

7070
conditions: List[bool] = []
7171
conditions.append(False)
72-
# Bias Add should be done in shared memory
72+
# Bias Add should be performed in shared memory
7373
conditions.append(with_bias)
7474
return any(conditions) # Always set to False Currently
7575

@@ -227,8 +227,8 @@ def apply_config(
227227
def main(
228228
A: T.Buffer(A_shape, in_dtype),
229229
B: T.Buffer(B_shape, in_dtype),
230-
C: T.Buffer(C_shape, out_dtype),
231230
Bias: T.Buffer(Bias_shape, out_dtype),
231+
C: T.Buffer(C_shape, out_dtype),
232232
):
233233
with T.Kernel(
234234
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
@@ -444,16 +444,15 @@ def apply_config(
444444
chunk=chunk,
445445
)
446446

447-
# cache_write_required = self.check_require_cache()
448-
cache_write_required = False
447+
cache_write_required = self.check_require_cache()
449448

450449
# Define the main kernel using the generated configuration
451450
@T.prim_func
452451
def main(
453452
A: T.Buffer(A_shape, in_dtype),
454453
B: T.Buffer(B_shape, in_dtype),
455-
C: T.Buffer(C_shape, out_dtype),
456454
Bias: T.Buffer(Bias_shape, out_dtype),
455+
C: T.Buffer(C_shape, out_dtype),
457456
):
458457
# Grid and thread configuration for CUDA kernel
459458
with T.Kernel(
@@ -667,8 +666,8 @@ def apply_config(
667666
def main(
668667
A: T.Buffer(A_shape, in_dtype),
669668
B: T.Buffer(B_shape, in_dtype),
670-
C: T.Buffer(C_shape, out_dtype),
671669
Bias: T.Buffer(Bias_shape, out_dtype),
670+
C: T.Buffer(C_shape, out_dtype),
672671
):
673672
# Grid and thread configuration for CUDA kernel
674673
with T.Kernel(
@@ -867,6 +866,8 @@ def apply_config(
867866
in_dtype, out_dtype, accum_dtype = self.in_dtype, self.out_dtype, self.accum_dtype
868867
assert in_dtype == "int4", "Only support int4 input"
869868
assert accum_dtype == "int32", "Only support int32 accumulation"
869+
with_bias = self.with_bias
870+
assert not with_bias, "Currently do not support bias"
870871
storage_dtype = "int8"
871872

872873
# Calculate the micro size per warp using a helper function
@@ -879,6 +880,8 @@ def apply_config(
879880
# Define the shapes of matrices and shared memory buffers
880881
A_shape = (M, K)
881882
B_shape = (N, K)
883+
Bias_shape = (N,)
884+
C_shape = (M, N)
882885
A_shared_shape = (block_M, block_K)
883886
B_shared_shape = (block_N, block_K)
884887
C_shared_shape = (
@@ -918,7 +921,8 @@ def apply_config(
918921
def main(
919922
A: T.Buffer(A_shape, storage_dtype),
920923
B: T.Buffer(B_shape, storage_dtype),
921-
C: T.Buffer((M, N), out_dtype),
924+
Bias: T.Buffer(Bias_shape, out_dtype),
925+
C: T.Buffer(C_shape, out_dtype),
922926
):
923927
# Grid and thread configuration for CUDA kernel
924928
with T.Kernel(

bitblas/ops/general_matmul/tilelang/dequantize/gemv_dequantize_simt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,8 +168,8 @@ def main(
168168
Scale: T.Buffer(Scale_shape, in_dtype),
169169
Qzeros: T.Buffer(Qzeros_shape, storage_dtype),
170170
Zeros: T.Buffer(Zeros_shape, in_dtype),
171-
C: T.Buffer(C_shape, out_dtype),
172171
Bias: T.Buffer(Bias_shape, in_dtype),
172+
C: T.Buffer(C_shape, out_dtype),
173173
):
174174
with T.Kernel(
175175
T.ceildiv(N, n_partition),

bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_simt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -624,8 +624,8 @@ def general_shared_dequant_matmul(
624624
Scale: T.Buffer(Scale_shape, in_dtype),
625625
Qzeros: T.Buffer(Qzeros_shape, storage_dtype),
626626
Zeros: T.Buffer(Zeros_shape, in_dtype),
627-
C: T.Buffer(C_shape, out_dtype),
628627
Bias: T.Buffer(Bias_shape, in_dtype),
628+
C: T.Buffer(C_shape, out_dtype),
629629
):
630630
with T.Kernel(
631631
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):

0 commit comments

Comments
 (0)