Merged
Changes from all commits
26 commits
5 changes: 2 additions & 3 deletions bitblas/base/roller/policy/tensorcore.py
@@ -328,9 +328,8 @@ def _score(node, thread): # small is better
# TODO: This is a dummy mul which avoid reusing some shared memory.
# Should be removed in the future.
if td.smem_cost > (self.arch.smem_cap):
debug_message = f"Tile Dict: {td.output_tile} Shared memory exceeds the static capacity," \
" use dynamic shared memory."
logger.debug(debug_message)
# Tile Dict: {td.output_tile} Shared memory exceeds the static capacity
# use dynamic shared memory.
codegen_dict.shared_scope = "shared.dyn"

codegen_dict.shared_scope = "shared.dyn"
6 changes: 3 additions & 3 deletions bitblas/builder/wrapper/tl.py
@@ -19,9 +19,9 @@ class TLCUDASourceWrapper(object):
_TYPE_MAP = {
"float32": "float",
"float16": "half_t",
"bfloat16": "__nv_bfloat16",
"e4m3_float8": "__nv_fp8_e4m3",
"e5m2_float8": "__nv_fp8_e5m2",
"bfloat16": "bfloat16_t",
"e4m3_float8": "float_e4m3_t",
"e5m2_float8": "float_e5m2_t",
"float64": "double",
"int64": "int64_t",
"int32": "int",
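_TYPE_MAP translates BitBLAS dtype strings into the C++ type names emitted in the generated TileLang kernel source; the updated entries switch bfloat16 and the FP8 formats to the CUTLASS-style names. A minimal sketch of how such a map might be consumed when rendering a kernel signature; the render_signature helper and its arguments are illustrative, not code from this PR:

    # Illustrative only: a dtype-string -> C++ type-name map like _TYPE_MAP used to
    # render a kernel parameter list. The helper below is an assumption for the sketch.
    TYPE_MAP = {
        "float16": "half_t",
        "bfloat16": "bfloat16_t",       # CUTLASS bfloat16 (new mapping)
        "e4m3_float8": "float_e4m3_t",  # CUTLASS FP8 types (new mappings)
        "e5m2_float8": "float_e5m2_t",
    }

    def render_signature(kernel_name, args):
        # args: list of (buffer name, BitBLAS dtype string) pairs
        params = ", ".join(f"{TYPE_MAP[dtype]}* {name}" for name, dtype in args)
        return f'extern "C" void {kernel_name}({params});'

    print(render_signature("matmul_kernel", [("A", "float16"), ("B", "bfloat16")]))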
2 changes: 1 addition & 1 deletion bitblas/module/__init__.py
@@ -274,7 +274,7 @@ def forward(self, A, output=None):
self.init_params()
args = [A_void, *self.q_params]
if output is None:
output = torch.empty(
output = torch.zeros(
A.shape[:-1] + (self.out_features,),
dtype=getattr(torch, self.bitblas_matmul.out_dtype),
device=A.device)
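The forward path now allocates the output with torch.zeros instead of torch.empty, so the buffer holds defined values before the kernel runs (relevant if the kernel accumulates into the output or does not overwrite every element). A minimal standalone sketch of the allocation after this change; the helper itself is not part of the PR:

    import torch

    def alloc_output(A: torch.Tensor, out_features: int, out_dtype: str = "float16") -> torch.Tensor:
        # Zero-filled rather than uninitialized, matching the diff above.
        return torch.zeros(
            A.shape[:-1] + (out_features,),
            dtype=getattr(torch, out_dtype),
            device=A.device)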
31 changes: 23 additions & 8 deletions bitblas/ops/general_matmul/__init__.py
@@ -15,6 +15,7 @@
from .tilelang.dense import select_scheduler as consistent_scheduler
from .tilelang.dequantize import select_scheduler as weight_dequantize_scheduler
from ...base.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2
from bitblas.utils import retrieve_func_from_module
from bitblas.utils.target_detector import auto_detect_nvidia_target
from dataclasses import dataclass
from ..ladder_permutate import LadderPermutate, LadderPermutateConfig
@@ -350,7 +351,7 @@ def __init__(
target: Optional[Union[str, Target]] = None,
enable_tuning: bool = True,
from_database: bool = False,
backend: str = "tir",
backend: str = "tl",
):
# if from database, we should disable default schedule
# to save compilation time
@@ -383,13 +384,13 @@ def __init__(
if target.kind.name not in ("cuda", "hip"):
raise ValueError("Currently only support cuda and hip target")

self.dispatch_tir(target, from_database, source_format, enable_tuning)
self.dispatch(target, from_database, source_format, enable_tuning)

def dispatch_tir(self,
target: Target,
from_database: bool = False,
source_format: str = "uint",
enable_tuning: bool = True):
def dispatch(self,
target: Target,
from_database: bool = False,
source_format: str = "uint",
enable_tuning: bool = True):

if isinstance(self.M, Tuple):
self.dynamic_range = {"m": self.M}
@@ -638,7 +639,21 @@ def post_process(self, code: str) -> str:
return code

def retrieve_weight_shape(self):
return [int(i) for i in self.prim_func.buffer_map[self.prim_func.params[1]].shape]
prim_func = self.prim_func

# retrieve from tilelang backend
if prim_func is None and self.scheduled_ir_module is not None:
prim_func = retrieve_func_from_module(self.scheduled_ir_module)

if prim_func is None and self.is_tilelang_backend():
# If from_database and from tilelang backend, we should construct a default module
self._update_optimized_mod(self.scheduler_with_default(self.scheduler))
prim_func = retrieve_func_from_module(self.scheduled_ir_module)

if prim_func is not None:
return [int(i) for i in prim_func.buffer_map[prim_func.params[1]].shape]

raise ValueError("The weight shape is not available.")

def transform_weight(self, weight, scale=None, zeros=None, bias=None):
"""
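Three things change in this file: the default backend becomes "tl" (TileLang), dispatch_tir is generalized to dispatch, and retrieve_weight_shape gains fallbacks so it still works when no TIR prim_func exists (TileLang backend, including the from_database path, where a default-scheduled module is built on demand). A hedged usage sketch; the MatmulConfig fields shown are the usual public ones and are assumptions here, not taken from this diff:

    # Illustrative usage of the new default backend and the hardened
    # retrieve_weight_shape; field names on MatmulConfig are assumed.
    import bitblas

    config = bitblas.MatmulConfig(
        M=1, N=1024, K=1024,
        A_dtype="float16", W_dtype="int4",
        out_dtype="float16", with_bias=False,
    )
    matmul = bitblas.Matmul(config, enable_tuning=False)  # backend now defaults to "tl"

    # With the TileLang backend there may be no TIR prim_func; the new fallbacks pull
    # the shape from the scheduled IR module (or a default-scheduled one) and only
    # raise if nothing is available.
    print(matmul.retrieve_weight_shape())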
5 changes: 2 additions & 3 deletions bitblas/ops/general_matmul/tilelang/dense/gemv_simt.py
@@ -111,8 +111,8 @@ def apply_config(

A_shape = (M, K)
B_shape = (N, K)
C_shape = (M, N)
Bias_shape = (N,)
C_shape = (M, N)

dp4a_size = 4
use_dp4a = in_dtype == "int8" and accum_dtype == "int32"
@@ -121,8 +121,8 @@
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype),
C: T.Buffer(C_shape, out_dtype),
Bias: T.Buffer(Bias_shape, out_dtype),
C: T.Buffer(C_shape, out_dtype),
):
with T.Kernel(
T.ceildiv(N, n_partition), M, threads=(reduce_thread, n_partition)) as (
@@ -186,5 +186,4 @@ def __post_init__(self):
# Validate the matrix transpose settings
assert self.trans_A is False, "Currently only support Matrix A not transposed"
assert self.trans_B is True, "Currently only support Matrix B transposed"
assert self.with_bias is False, "Currently only support without bias"
return
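In this schedule the Bias buffer now precedes C in the kernel signature and the "without bias" assertion in __post_init__ is dropped, i.e. bias is now accepted by the SIMT GEMV path (the same reordering is applied to the other kernels below). For reference, what these schedules compute per the remaining asserts (A not transposed, B transposed); this reference helper is illustrative only, not PR code:

    from typing import Optional
    import torch

    def reference(A: torch.Tensor, B: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor:
        C = A @ B.T            # trans_A = False, trans_B = True
        if bias is not None:
            C = C + bias       # bias broadcast along N
        return C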
1 change: 0 additions & 1 deletion bitblas/ops/general_matmul/tilelang/dense/matmul.py
@@ -242,7 +242,6 @@ def __post_init__(self):
# Validate the matrix transpose settings
assert self.trans_A is False, "Currently only support Matrix A not transposed"
assert self.trans_B is True, "Currently only support Matrix B transposed"
assert self.with_bias is False, "Currently only support without bias"
assert self.input_transform_kind == TransformKind.NonTransform, "Currently only support NonTransform for input"

return
31 changes: 21 additions & 10 deletions bitblas/ops/general_matmul/tilelang/dense/matmul_simt.py
@@ -57,7 +57,7 @@ def check_require_cache(self) -> bool:

conditions: List[bool] = []
conditions.append(False)
# Bias Add should be done in shared memory
# Bias Add should be performed in shared memory
conditions.append(with_bias)
return any(conditions) # Always set to False Currently

@@ -172,6 +172,8 @@ def apply_config(
self.accum_dtype,
)

with_bias = self.with_bias

shared_scope = "shared.dyn"

block_M = block_size_x * thread_row_tiles
@@ -183,6 +185,7 @@
C_shape = (M, N)
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, block_K)
Bias_shape = (N,)

threads = thread_row_tiles * thread_col_tiles
local_size_a = block_M // thread_row_tiles
@@ -198,6 +201,7 @@
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype),
Bias: T.Buffer(Bias_shape, out_dtype),
C: T.Buffer(C_shape, out_dtype),
):
with T.Kernel(
@@ -249,21 +253,28 @@ def main(
else:
for dp4a_idx in T.serial(dp4a_size):
C_local[i * local_size_b + j] += (
A_local[i, mk * dp4a_size + dp4a_idx] *
B_local[j, mk * dp4a_size + dp4a_idx])

for i, j in T.grid(local_size_a, local_size_b):
C[
by * block_M + warp_m * local_size_a + i,
bx * block_N + warp_n * local_size_b + j,
] = C_local[i * local_size_b + j]
A_local[i,
mk * dp4a_size + dp4a_idx].astype(accum_dtype) *
B_local[j,
mk * dp4a_size + dp4a_idx].astype(accum_dtype))

if with_bias:
for i, j in T.grid(local_size_a, local_size_b):
C_local[i * local_size_b + j] += Bias[bx * block_N + warp_n * local_size_b +
j]

for i in T.serial(local_size_a):
for j in T.vectorized(local_size_b):
C[
by * block_M + warp_m * local_size_a + i,
bx * block_N + warp_n * local_size_b + j,
] = C_local[i * local_size_b + j]

return self.post_process(main)

def __post_init__(self):
# Validate the matrix transpose settings
assert self.trans_A is False, "Currently only support Matrix A not transposed"
assert self.trans_B is True, "Currently only support Matrix B transposed"
assert self.with_bias is False, "Currently only support without bias"

return
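Besides wiring up the optional bias add and a vectorized store, the non-dp4a fallback now casts both operands to accum_dtype before multiplying. With int8 inputs, for example, the product of two int8 values can exceed the int8 range, so widening only after the multiply would already have wrapped. A small NumPy illustration of that difference; illustrative only, not kernel code:

    import numpy as np

    a = np.array([127], dtype=np.int8)
    b = np.array([127], dtype=np.int8)

    exact = a.astype(np.int32) * b.astype(np.int32)  # [16129] -- widen first, as the kernel now does
    wrapped = (a * b).astype(np.int32)               # [1]     -- int8 multiply wraps before widening

    print(exact, wrapped)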
18 changes: 11 additions & 7 deletions bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py
@@ -69,7 +69,7 @@ def check_require_cache(self) -> bool:

conditions: List[bool] = []
conditions.append(False)
# Bias Add should be done in shared memory
# Bias Add should be performed in shared memory
conditions.append(with_bias)
return any(conditions) # Always set to False Currently

@@ -227,8 +227,8 @@ def apply_config(
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype),
C: T.Buffer(C_shape, out_dtype),
Bias: T.Buffer(Bias_shape, out_dtype),
C: T.Buffer(C_shape, out_dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
@@ -444,16 +444,15 @@ def apply_config(
chunk=chunk,
)

# cache_write_required = self.check_require_cache()
cache_write_required = False
cache_write_required = self.check_require_cache()

# Define the main kernel using the generated configuration
@T.prim_func
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype),
C: T.Buffer(C_shape, out_dtype),
Bias: T.Buffer(Bias_shape, out_dtype),
C: T.Buffer(C_shape, out_dtype),
):
# Grid and thread configuration for CUDA kernel
with T.Kernel(
@@ -667,8 +666,8 @@ def apply_config(
def main(
A: T.Buffer(A_shape, in_dtype),
B: T.Buffer(B_shape, in_dtype),
C: T.Buffer(C_shape, out_dtype),
Bias: T.Buffer(Bias_shape, out_dtype),
C: T.Buffer(C_shape, out_dtype),
):
# Grid and thread configuration for CUDA kernel
with T.Kernel(
@@ -867,6 +866,8 @@ def apply_config(
in_dtype, out_dtype, accum_dtype = self.in_dtype, self.out_dtype, self.accum_dtype
assert in_dtype == "int4", "Only support int4 input"
assert accum_dtype == "int32", "Only support int32 accumulation"
with_bias = self.with_bias
assert not with_bias, "Currently do not support bias"
storage_dtype = "int8"

# Calculate the micro size per warp using a helper function
@@ -879,6 +880,8 @@
# Define the shapes of matrices and shared memory buffers
A_shape = (M, K)
B_shape = (N, K)
Bias_shape = (N,)
C_shape = (M, N)
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, block_K)
C_shared_shape = (
@@ -918,7 +921,8 @@
def main(
A: T.Buffer(A_shape, storage_dtype),
B: T.Buffer(B_shape, storage_dtype),
C: T.Buffer((M, N), out_dtype),
Bias: T.Buffer(Bias_shape, out_dtype),
C: T.Buffer(C_shape, out_dtype),
):
# Grid and thread configuration for CUDA kernel
with T.Kernel(
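Further up in this file, the hard-coded cache_write_required = False is replaced by the real check_require_cache() call, so kernels with a bias can now take the shared-memory cache-write epilogue. The decision distilled from the diff into a standalone sketch, with with_bias passed explicitly instead of read from self:

    def check_require_cache(with_bias: bool) -> bool:
        conditions = [False]           # placeholder for other conditions
        conditions.append(with_bias)   # bias add is performed in shared memory
        return any(conditions)

    assert check_require_cache(with_bias=True) is True
    assert check_require_cache(with_bias=False) is False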
@@ -168,8 +168,8 @@ def main(
Scale: T.Buffer(Scale_shape, in_dtype),
Qzeros: T.Buffer(Qzeros_shape, storage_dtype),
Zeros: T.Buffer(Zeros_shape, in_dtype),
C: T.Buffer(C_shape, out_dtype),
Bias: T.Buffer(Bias_shape, in_dtype),
C: T.Buffer(C_shape, out_dtype),
):
with T.Kernel(
T.ceildiv(N, n_partition),
@@ -624,8 +624,8 @@ def general_shared_dequant_matmul(
Scale: T.Buffer(Scale_shape, in_dtype),
Qzeros: T.Buffer(Qzeros_shape, storage_dtype),
Zeros: T.Buffer(Zeros_shape, in_dtype),
C: T.Buffer(C_shape, out_dtype),
Bias: T.Buffer(Bias_shape, in_dtype),
C: T.Buffer(C_shape, out_dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
@@ -623,8 +623,8 @@ def general_shared_dequant_matmul(
Scale: T.Buffer(Scale_shape, in_dtype),
Qzeros: T.Buffer(Qzeros_shape, storage_dtype),
Zeros: T.Buffer(Zeros_shape, in_dtype),
C: T.Buffer(C_shape, out_dtype),
Bias: T.Buffer(Bias_shape, in_dtype),
C: T.Buffer(C_shape, out_dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):