diff --git a/3rdparty/tilelang b/3rdparty/tilelang
index e3b1856dd..6aef1f896 160000
--- a/3rdparty/tilelang
+++ b/3rdparty/tilelang
@@ -1 +1 @@
-Subproject commit e3b1856dd90947cc4992b5cab6537fa87ecb835e
+Subproject commit 6aef1f8968bb3f8f806b74eb1334eb2a44a9ab3a
diff --git a/bitblas/builder/wrapper/tl.py b/bitblas/builder/wrapper/tl.py
index 85a75601a..c64ca2b8d 100644
--- a/bitblas/builder/wrapper/tl.py
+++ b/bitblas/builder/wrapper/tl.py
@@ -18,10 +18,10 @@ class TLCUDASourceWrapper(object):
     _TYPE_MAP = {
         "float32": "float",
-        "float16": "half_t",
-        "bfloat16": "bfloat16_t",
-        "e4m3_float8": "float_e4m3_t",
-        "e5m2_float8": "float_e5m2_t",
+        "float16": "half",
+        "bfloat16": "__nv_bfloat16",
+        "e4m3_float8": "__nv_fp8_e4m3",
+        "e5m2_float8": "__nv_fp8_e5m2",
         "float64": "double",
         "int64": "int64_t",
         "int32": "int",
diff --git a/bitblas/ops/general_matmul/tilelang/dense/gemv_simt.py b/bitblas/ops/general_matmul/tilelang/dense/gemv_simt.py
index ce19f7c80..2f9eaa7c8 100644
--- a/bitblas/ops/general_matmul/tilelang/dense/gemv_simt.py
+++ b/bitblas/ops/general_matmul/tilelang/dense/gemv_simt.py
@@ -158,7 +158,8 @@ def main(
                 )
             else:
                 for ki in T.serial(micro_size_k):
-                    accum_res[0] += A_local[ki] * B_local[ki]
+                    accum_res[0] += A_local[ki].astype(accum_dtype) * B_local[ki].astype(
+                        accum_dtype)

             with T.attr(
                 T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]),
diff --git a/testing/python/operators/test_general_matmul_fp8.py b/testing/python/operators/test_general_matmul_fp8.py
index 0ef2e64af..0cdf2d545 100644
--- a/testing/python/operators/test_general_matmul_fp8.py
+++ b/testing/python/operators/test_general_matmul_fp8.py
@@ -27,7 +27,7 @@ def matmul_torch_forward(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layo
         with_zeros=with_zeros,
         zeros_mode=zeros_mode,
     )
-    matmul = Matmul(config=matmul_config, enable_tuning=True, backend="tir")
+    matmul = Matmul(config=matmul_config, enable_tuning=True, backend="tl")

     input_shape = (M, K)
     weight_shape = (N, K) if layout == "nt" else (K, N)
@@ -93,7 +93,7 @@ def matmul_torch_forward_weight_dequantize(M, N, K, A_dtype, W_dtype, accum_dtyp
         propagate_a=False,
         propagate_b=False,
     )
-    matmul = Matmul(config=matmul_config, enable_tuning=False, backend="tir")
+    matmul = Matmul(config=matmul_config, enable_tuning=False, backend="tl")

     input_shape = (M, K)
     weight_shape = (N, K) if layout == "nt" else (K, N)
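Note (not part of the patch): a minimal PyTorch sketch of why the gemv_simt.py hunk casts each operand to accum_dtype before the multiply-accumulate. Summing fp16 products directly in fp16 accumulates rounding error (and can overflow), while widening to the accumulation dtype first keeps the reduction accurate. The vector length K and the tensor names below are illustrative, not taken from the patch.

import torch

torch.manual_seed(0)
K = 4096
a = torch.randn(K, dtype=torch.float16)
b = torch.randn(K, dtype=torch.float16)

# Old behavior: accum_res[0] += A_local[ki] * B_local[ki]
# (product and running sum both stay in fp16)
acc_fp16 = torch.zeros((), dtype=torch.float16)
for ki in range(K):
    acc_fp16 += a[ki] * b[ki]

# Fixed behavior: operands are cast to accum_dtype (float32 here)
# before the multiply, so the accumulation happens in fp32
acc_fp32 = torch.zeros((), dtype=torch.float32)
for ki in range(K):
    acc_fp32 += a[ki].to(torch.float32) * b[ki].to(torch.float32)

ref = (a.double() @ b.double()).item()  # high-precision reference
print(f"fp16 accumulation error: {abs(acc_fp16.item() - ref):.4f}")
print(f"fp32 accumulation error: {abs(acc_fp32.item() - ref):.6f}")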