Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions bitblas/ops/general_matmul/tilelang/dense/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
MatmulFineGrainSIMTScheduler, # noqa: F401
)

from .matmul_tensorcore import (
MatmulBlockScheduler,
MatmulFineGrainScheduler,
MatmulWeightPropagationScheduler,
MatmulINT4FineGrainScheduler,
MatmulINT4WeightPropagationScheduler,
from .matmul_tile import (
MatmulTileLibraryScheduler,)

from .matmul_mma import (
MatmulMMAScheduler,
MatmulMMAWeightPropagationScheduler,
MatmulINT4MMAScheduler,
MatmulINT4MMAWeightPropagationScheduler,
)

from .matmul import MatmulScheduler
Expand Down Expand Up @@ -126,8 +128,8 @@ def is_int4_dtype(dtype):
return dtype == "int4" or dtype == "uint4"

if can_apply_weight_propagation_scheduler(trans_A, trans_B, propagate_a, propagate_b):
Scheduler = MatmulWeightPropagationScheduler if not is_int4_dtype(
in_dtype) else MatmulINT4WeightPropagationScheduler
Scheduler = MatmulMMAWeightPropagationScheduler if not is_int4_dtype(
in_dtype) else MatmulINT4MMAWeightPropagationScheduler
return Scheduler(
M=M,
N=N,
Expand All @@ -140,8 +142,7 @@ def is_int4_dtype(dtype):
with_bias=with_bias,
)
if can_apply_fine_grain_scheduler(trans_A, trans_B, propagate_a, propagate_b):
Scheduler = MatmulFineGrainScheduler if not is_int4_dtype(
in_dtype) else MatmulINT4FineGrainScheduler
Scheduler = MatmulMMAScheduler if not is_int4_dtype(in_dtype) else MatmulINT4MMAScheduler
return Scheduler(
M=M,
N=N,
Expand All @@ -154,7 +155,7 @@ def is_int4_dtype(dtype):
with_bias=with_bias,
)
elif can_apply_block_scheduler(propagate_a, propagate_b):
return MatmulBlockScheduler(
return MatmulTileLibraryScheduler(
M=M,
N=N,
K=K,
Expand Down
34 changes: 18 additions & 16 deletions bitblas/ops/general_matmul/tilelang/dense/matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@
from .base import MatmulBaseParams
from .gemv_simt import GemvFineGrainSIMTScheduler
from .matmul_simt import MatmulFineGrainSIMTScheduler
from .matmul_tensorcore import (
MatmulBlockScheduler,
MatmulFineGrainScheduler,
MatmulWeightPropagationScheduler,
MatmulINT4FineGrainScheduler,
MatmulINT4WeightPropagationScheduler,
from .matmul_tile import (
MatmulTileLibraryScheduler,)
from .matmul_mma import (
MatmulMMAScheduler,
MatmulMMAWeightPropagationScheduler,
MatmulINT4MMAScheduler,
MatmulINT4MMAWeightPropagationScheduler,
)

import logging
Expand All @@ -37,20 +38,21 @@ class MatmulScheduler(MatmulBaseParams):

gemv_scheduler: Optional[GemvFineGrainSIMTScheduler] = None
matmul_simt_scheduler: Optional[MatmulFineGrainSIMTScheduler] = None
matmul_block_scheduler: Optional[MatmulBlockScheduler] = None
matmul_fine_grain_scheduler: Optional[MatmulFineGrainScheduler] = None
matmul_weight_propagation_scheduler: Optional[MatmulWeightPropagationScheduler] = None
matmul_int4_fine_grain_scheduler: Optional[MatmulINT4FineGrainScheduler] = None
matmul_int4_weight_propagation_scheduler: Optional[MatmulINT4WeightPropagationScheduler] = None
matmul_block_scheduler: Optional[MatmulTileLibraryScheduler] = None
matmul_fine_grain_scheduler: Optional[MatmulMMAScheduler] = None
matmul_weight_propagation_scheduler: Optional[MatmulMMAWeightPropagationScheduler] = None
matmul_int4_fine_grain_scheduler: Optional[MatmulINT4MMAScheduler] = None
matmul_int4_weight_propagation_scheduler: Optional[
MatmulINT4MMAWeightPropagationScheduler] = None

def __init__(self, **kwargs):
self.gemv_scheduler = GemvFineGrainSIMTScheduler(**kwargs)
self.matmul_simt_scheduler = MatmulFineGrainSIMTScheduler(**kwargs)
self.matmul_block_scheduler = MatmulBlockScheduler(**kwargs)
self.matmul_fine_grain_scheduler = MatmulFineGrainScheduler(**kwargs)
self.matmul_weight_propagation_scheduler = MatmulWeightPropagationScheduler(**kwargs)
self.matmul_int4_fine_grain_scheduler = MatmulINT4FineGrainScheduler(**kwargs)
self.matmul_int4_weight_propagation_scheduler = MatmulINT4WeightPropagationScheduler(
self.matmul_block_scheduler = MatmulTileLibraryScheduler(**kwargs)
self.matmul_fine_grain_scheduler = MatmulMMAScheduler(**kwargs)
self.matmul_weight_propagation_scheduler = MatmulMMAWeightPropagationScheduler(**kwargs)
self.matmul_int4_fine_grain_scheduler = MatmulINT4MMAScheduler(**kwargs)
self.matmul_int4_weight_propagation_scheduler = MatmulINT4MMAWeightPropagationScheduler(
**kwargs)
super().__init__(**kwargs)

Expand Down
Loading
Loading