diff --git a/.gitmodules b/.gitmodules index d5b545545..adbfcc33f 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,8 +1,12 @@ [submodule "3rdparty/tvm"] path = 3rdparty/tvm - url = https://github.com/TileLang/tvm.git - branch = upstream + url = https://github.com/tile-ai/tvm.git + branch = tilelang_codebase +[submodule "3rdparty/tilelang"] + path = 3rdparty/tilelang + url = https://github.com/tile-ai/tilelang + branch = bitblas [submodule "3rdparty/cutlass"] path = 3rdparty/cutlass - url = https://github.com/TileLang/cutlass + url = https://github.com/tile-ai/cutlass branch = tldev diff --git a/3rdparty/tilelang b/3rdparty/tilelang new file mode 160000 index 000000000..e3b1856dd --- /dev/null +++ b/3rdparty/tilelang @@ -0,0 +1 @@ +Subproject commit e3b1856dd90947cc4992b5cab6537fa87ecb835e diff --git a/3rdparty/tvm b/3rdparty/tvm index 41edb06ed..b372d9ca2 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 41edb06ed039944978c671afbd2dde5f22667c83 +Subproject commit b372d9ca2159a1afd5439990f68bfa29578a8bac diff --git a/bitblas/__init__.py b/bitblas/__init__.py index 32ff07132..b6f4bdb35 100644 --- a/bitblas/__init__.py +++ b/bitblas/__init__.py @@ -90,7 +90,7 @@ def new_func(*args, **kwargs): if TVM_IMPORT_PYTHON_PATH is not None: os.environ["PYTHONPATH"] = (TVM_IMPORT_PYTHON_PATH + ":" + os.environ.get("PYTHONPATH", "")) - sys.path.insert(0, TVM_IMPORT_PYTHON_PATH + "/python") + sys.path.insert(0, TVM_IMPORT_PYTHON_PATH) else: # remove the existing tvm path in PYTHONPATH def remove_tvm_path(path): @@ -107,6 +107,7 @@ def remove_tvm_path(path): os.environ["PYTHONPATH"] = ( install_tvm_path + "/python:" + os.environ.get("PYTHONPATH", "")) sys.path.insert(0, install_tvm_path + "/python") + os.environ["TVM_IMPORT_PYTHON_PATH"] = install_tvm_path + "/python" # developed 3rdparty tvm develop_tvm_path = os.path.join( @@ -119,6 +120,22 @@ def remove_tvm_path(path): os.environ["PYTHONPATH"] = ( develop_tvm_path + "/python:" + os.environ.get("PYTHONPATH", "")) sys.path.insert(0, develop_tvm_path + "/python") + os.environ["TVM_IMPORT_PYTHON_PATH"] = develop_tvm_path + "/python" + +# TILELANG PATH +if os.environ.get("TILELANG_IMPORT_PATH", None) is None: + install_tilelang_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "3rdparty", "tilelang") + develop_tilelang_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "tilelang") + if os.path.exists(install_tilelang_path): + os.environ["PYTHONPATH"] = install_tilelang_path + ":" + os.environ.get("PYTHONPATH", "") + sys.path.insert(0, install_tilelang_path) + elif (os.path.exists(develop_tilelang_path) and develop_tilelang_path not in sys.path): + os.environ["PYTHONPATH"] = develop_tilelang_path + ":" + os.environ.get("PYTHONPATH", "") + sys.path.insert(0, develop_tilelang_path) + else: + logger.warning(TL_TEMPLATE_NOT_FOUND_MESSAGE) if os.environ.get("TL_CUTLASS_PATH", None) is None: install_cutlass_path = os.path.join( @@ -133,6 +150,8 @@ def remove_tvm_path(path): logger.warning(CUTLASS_NOT_FOUND_MESSAGE) import tvm as tvm # noqa: E402 +import tilelang as tilelang # noqa: E402 + from .base import ( TileDevice, # noqa: F401 fast_tune, # noqa: F401 diff --git a/bitblas/base/base_scheduler.py b/bitblas/base/base_scheduler.py index d901a4192..06b605640 100644 --- a/bitblas/base/base_scheduler.py +++ b/bitblas/base/base_scheduler.py @@ -1,9 +1,13 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+from bitblas import tvm as tvm +from bitblas import tilelang as tilelang from tvm import te from tvm import IRModule from tvm.tir import PrimFunc from typing import Optional, Union, Callable, List, Dict from dataclasses import dataclass, field -from tvm.tl.transform import Simplify +from tilelang.transform import Simplify from abc import ABC, abstractmethod from bitblas.base.arch import TileDevice, is_volta_arch, is_ampere_arch, is_cdna_arch, auto_infer_current_arch from bitblas.base.roller.hint import Hint diff --git a/bitblas/builder/lib_generator/__init__.py b/bitblas/builder/lib_generator/__init__.py index 642198060..ba4106259 100644 --- a/bitblas/builder/lib_generator/__init__.py +++ b/bitblas/builder/lib_generator/__init__.py @@ -67,24 +67,24 @@ def compile_lib(self, timeout: float = None, with_tl: bool = False): raise ValueError(f"Unsupported platform: {platform}") if with_tl: - install_tvm_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "../..", "3rdparty", "tvm") - develop_tvm_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "../../..", "3rdparty", "tvm") + install_tilelang_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "../..", "3rdparty", "tilelang") + develop_tilelang_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "../../..", "3rdparty", "tilelang") - tvm_root = next((path for path in [install_tvm_path, develop_tvm_path] - if os.path.exists(path) and path not in sys.path), None) + tilelang_root = next((path for path in [install_tilelang_path, develop_tilelang_path] + if os.path.exists(path) and path not in sys.path), None) if "TL_TEMPLATE_PATH " in os.environ: tl_template_path = os.environ["TL_TEMPLATE_PATH"] else: - tl_template_path = osp.abspath(osp.join(tvm_root, "src/tl")) + tl_template_path = osp.abspath(osp.join(tilelang_root, "src")) - tl_template_path = osp.abspath(osp.join(tvm_root, "src/tl")) + tl_template_path = osp.abspath(osp.join(tilelang_root, "src")) if "TL_CUTLASS_PATH" in os.environ: cutlass_path = os.environ["TL_CUTLASS_PATH"] else: - cutlass_path = osp.abspath(osp.join(tvm_root, "3rdparty/cutlass/include")) + cutlass_path = osp.abspath(osp.join(tilelang_root, "3rdparty/cutlass/include")) command += [ "-I" + tl_template_path, diff --git a/bitblas/gpu/intrin/hip.py b/bitblas/gpu/intrin/hip.py index 9883eaed1..4ac668e4d 100644 --- a/bitblas/gpu/intrin/hip.py +++ b/bitblas/gpu/intrin/hip.py @@ -1,5 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
+from bitblas import tvm as tvm +from bitblas import tilelang as tilelang from tvm.runtime import convert from tvm.tir.expr import Cast, IntImm from tvm.tir.function import TensorIntrin diff --git a/bitblas/ops/general_flashatten/tilelang/flashatten.py b/bitblas/ops/general_flashatten/tilelang/flashatten.py index d2a5b2857..81e00076f 100644 --- a/bitblas/ops/general_flashatten/tilelang/flashatten.py +++ b/bitblas/ops/general_flashatten/tilelang/flashatten.py @@ -3,7 +3,8 @@ from bitblas import tvm as tvm from bitblas.base.base_scheduler import BaseScheduler -import tvm.tl.language as T +from bitblas import tilelang as tilelang +import tilelang.language as T from dataclasses import dataclass from typing import Optional import logging diff --git a/bitblas/ops/general_matmul/tilelang/dense/gemv_simt.py b/bitblas/ops/general_matmul/tilelang/dense/gemv_simt.py index 5891acb14..ce19f7c80 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/gemv_simt.py +++ b/bitblas/ops/general_matmul/tilelang/dense/gemv_simt.py @@ -1,9 +1,10 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. from bitblas import tvm as tvm +from bitblas import tilelang as tilelang from functools import reduce from typing import Optional, List -import tvm.tl.language as T +import tilelang.language as T from tvm import DataType from tvm.tir import PrimFunc diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul_mma.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_mma.py index c28ebee9e..f36c05663 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/matmul_mma.py +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_mma.py @@ -3,8 +3,9 @@ # tile represents tile library from bitblas import tvm as tvm +from bitblas import tilelang as tilelang from tvm import DataType -import tvm.tl.language as T +import tilelang.language as T from typing import Optional, List from bitblas.tl.utils import ( get_mma_micro_size, diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul_simt.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_simt.py index 5c44daae3..7d42bb1c9 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/matmul_simt.py +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_simt.py @@ -1,8 +1,9 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. from bitblas import tvm as tvm +from bitblas import tilelang as tilelang from typing import Optional, List -import tvm.tl.language as T +import tilelang.language as T from tvm import DataType from tvm.tir import PrimFunc diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul_tile.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_tile.py index 8dceefd73..a906bc308 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/matmul_tile.py +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_tile.py @@ -3,7 +3,8 @@ # tile represents tile library from bitblas import tvm as tvm -import tvm.tl.language as T +from bitblas import tilelang as tilelang +import tilelang.language as T from typing import Optional, List from bitblas.tl.utils import ( get_mma_micro_size, diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/gemv_dequantize_simt.py b/bitblas/ops/general_matmul/tilelang/dequantize/gemv_dequantize_simt.py index 4f0f3b0c1..bf1d59081 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/gemv_dequantize_simt.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/gemv_dequantize_simt.py @@ -1,9 +1,10 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
from bitblas import tvm as tvm +from bitblas import tilelang as tilelang from functools import reduce from typing import Optional, List -import tvm.tl.language as T +import tilelang.language as T from tvm import DataType from tvm.tir import PrimFunc diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_mma.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_mma.py index f651d8c0b..aea3d331e 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_mma.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_mma.py @@ -2,7 +2,8 @@ # Licensed under the MIT License. from bitblas import tvm as tvm from tvm import DataType -import tvm.tl.language as T +from bitblas import tilelang as tilelang +import tilelang.language as T from typing import Optional, List from bitblas.tl.utils import ( get_mma_micro_size, # noqa: F401 diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_mma_weight_transform.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_mma_weight_transform.py index 6e24ab2da..67330730d 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_mma_weight_transform.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_mma_weight_transform.py @@ -1,8 +1,9 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. from bitblas import tvm as tvm +from bitblas import tilelang as tilelang from tvm import DataType -import tvm.tl.language as T +import tilelang.language as T from typing import Optional, List from bitblas.tl.utils import ( get_mma_micro_size, # noqa: F401 @@ -680,7 +681,8 @@ def is_b_smooth(self): @dataclass -class MatmulINT4DequantizeMMAWeightPropagationScheduler(MatmulDequantizeMMAWeightPropagationScheduler): +class MatmulINT4DequantizeMMAWeightPropagationScheduler( + MatmulDequantizeMMAWeightPropagationScheduler): class TLHint(MatmulDequantizeMMAWeightPropagationScheduler.TLHint): hint_type: str = "MatmulINT4DequantizeMMAWeightPropagationScheduler" diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_simt.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_simt.py index 0fda0b2ad..b903f5140 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_simt.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_simt.py @@ -3,7 +3,8 @@ from bitblas import tvm as tvm from tvm import DataType from tvm.tir import PrimFunc -import tvm.tl.language as T +from bitblas import tilelang as tilelang +import tilelang.language as T from typing import Optional, List from bitblas.base.arch import TileDevice from bitblas.base.roller.hint import Hint diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tile.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tile.py index c605ad9d9..06ab8ee7c 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tile.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tile.py @@ -4,7 +4,8 @@ from bitblas import tvm as tvm from tvm import DataType -import tvm.tl.language as T +from bitblas import tilelang as tilelang +import tilelang.language as T from typing import Optional, List from bitblas.base.arch import TileDevice from bitblas.base.roller.hint import Hint diff --git a/bitblas/ops/operator.py b/bitblas/ops/operator.py index ea99490c5..304723eb8 100644 --- a/bitblas/ops/operator.py +++ b/bitblas/ops/operator.py @@ -1,8 
+1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. from abc import ABC, abstractmethod -from bitblas import tvm -from tvm import tl +from bitblas import tvm as tvm +from bitblas import tilelang as tilelang from tvm import IRModule from tvm.runtime.module import Module from tvm.target import Target @@ -192,7 +192,7 @@ def tvm_callback_hip_postproc(code, _): if self.is_tir_backend(): rt_mod = tvm.build(self.scheduled_ir_module, target=target) elif self.is_tilelang_backend(): - rt_mod = tl.lower( + rt_mod = tilelang.lower( self.scheduled_ir_module, target=target, runtime_only=True) else: raise ValueError(f"Unsupported backend: {self.backend}") diff --git a/bitblas/tl/mfma_layout.py b/bitblas/tl/mfma_layout.py index 79e75e438..af0d3a47b 100644 --- a/bitblas/tl/mfma_layout.py +++ b/bitblas/tl/mfma_layout.py @@ -1,8 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. - +from bitblas import tilelang as tilelang from tvm import DataType -import tvm.tl.language as T +import tilelang.language as T from tvm.runtime import convert diff --git a/bitblas/tl/mfma_macro_generator.py b/bitblas/tl/mfma_macro_generator.py index 0148cd8bf..3a0cc2582 100644 --- a/bitblas/tl/mfma_macro_generator.py +++ b/bitblas/tl/mfma_macro_generator.py @@ -1,7 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import tvm.tl.language as T +from bitblas import tilelang as tilelang +import tilelang.language as T from typing import Tuple from tvm import DataType from tvm.tir import PrimExpr diff --git a/bitblas/tl/mma_layout.py b/bitblas/tl/mma_layout.py index aad3ff955..c091b756a 100644 --- a/bitblas/tl/mma_layout.py +++ b/bitblas/tl/mma_layout.py @@ -3,7 +3,8 @@ from typing import Union from tvm import arith, DataType -import tvm.tl.language as T +from bitblas import tilelang as tilelang +import tilelang.language as T def ldmatrix_32x8_to_shared_16x16_layout(thread_id, local_id): diff --git a/bitblas/tl/mma_macro_generator.py b/bitblas/tl/mma_macro_generator.py index f28233911..92169cbc8 100644 --- a/bitblas/tl/mma_macro_generator.py +++ b/bitblas/tl/mma_macro_generator.py @@ -1,7 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import tvm.tl.language as T +from bitblas import tilelang as tilelang +import tilelang.language as T from typing import Union, Tuple, Optional from bitblas.base.operator_common import TransformKind from tvm import DataType diff --git a/bitblas/tl/tuner.py b/bitblas/tl/tuner.py index a275e91f6..0c6eded9c 100644 --- a/bitblas/tl/tuner.py +++ b/bitblas/tl/tuner.py @@ -1,7 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
-from bitblas import tvm +from bitblas import tvm as tvm +from bitblas import tilelang as tilelang import os import logging import tempfile @@ -10,7 +11,6 @@ from tvm import IRModule from tvm.runtime import Module from tvm.tir import Schedule -import tvm.tl as tl from bitblas.tl.base_hint import BaseTLHint from bitblas.base.arch import TileDevice from bitblas.base.utils import get_dummy_input_arrays @@ -122,7 +122,7 @@ def tvm_callback_cuda_postproc(code, _): "tir.disable_cse_tir": True, **(config.pass_context if config.pass_context else {}) }): - rt_mod = tl.lower(tl_prim_func, arch.target, runtime_only=True) + rt_mod = tilelang.lower(tl_prim_func, arch.target, runtime_only=True) from tvm.contrib.tar import tar # Import the tar module diff --git a/bitblas/tl/wmma_macro_generator.py b/bitblas/tl/wmma_macro_generator.py index 0b81c1b04..48a42bd1a 100644 --- a/bitblas/tl/wmma_macro_generator.py +++ b/bitblas/tl/wmma_macro_generator.py @@ -1,7 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import tvm.tl.language as T +from bitblas import tilelang as tilelang +import tilelang.language as T from typing import Tuple, Optional from tvm import DataType from tvm.tir import PrimExpr diff --git a/bitblas/utils/rtmod_analysis.py b/bitblas/utils/rtmod_analysis.py index e3fe4c1cb..87b3c24c3 100644 --- a/bitblas/utils/rtmod_analysis.py +++ b/bitblas/utils/rtmod_analysis.py @@ -7,8 +7,8 @@ from tvm.target import Target from typing import Tuple, List from tvm import tir -from tvm import tl -from tvm.tl.engine import is_device_call +from bitblas import tilelang as tilelang +from tilelang.engine import is_device_call def get_annotated_device_mod_from_tl(mod: IRModule, target: Target) -> "IRModule": @@ -16,18 +16,18 @@ def get_annotated_device_mod_from_tl(mod: IRModule, target: Target) -> "IRModule target = tvm.target.Target(target, target_host) mod = tir.transform.BindTarget(target)(mod) - mod = tl.transform.FrontendLegalize()(mod) + mod = tilelang.transform.FrontendLegalize()(mod) mod = tir.transform.Simplify()(mod) - mod = tl.transform.LayoutInference()(mod) - mod = tl.transform.LowerTileOp()(mod) + mod = tilelang.transform.LayoutInference()(mod) + mod = tilelang.transform.LowerTileOp()(mod) mod = tir.transform.Simplify()(mod) if target.arch == "sm_90": - mod = tl.transform.WarpSpecializedPipeline()(mod) + mod = tilelang.transform.WarpSpecializedPipeline()(mod) else: mod = tir.transform.PlanAndUpdateBufferAllocationLocation()(mod) - mod = tl.transform.PipelinePlanning()(mod) - mod = tl.transform.InjectSoftwarePipeline()(mod) + mod = tilelang.transform.PipelinePlanning()(mod) + mod = tilelang.transform.InjectSoftwarePipeline()(mod) mod = tir.transform.LowerOpaqueBlock()(mod) mod = tir.transform.FlattenBuffer()(mod) @@ -57,7 +57,7 @@ def get_annotated_device_mod_from_tl(mod: IRModule, target: Target) -> "IRModule # the Legalization. mod = tir.transform.LowerThreadAllreduce()(mod) mod = tir.transform.ThreadSync("shared.dyn")(mod) - mod = tl.transform.LowerHopperIntrin()(mod) + mod = tilelang.transform.LowerHopperIntrin()(mod) mod = tir.transform.InjectPTXAsyncCopy()(mod) mod = tir.transform.AnnotateDeviceRegions()(mod) diff --git a/install.sh b/install.sh index 49d1fa815..77c258706 100755 --- a/install.sh +++ b/install.sh @@ -119,22 +119,70 @@ else echo "TVM build completed successfully." fi -cd ../../.. +TVM_PREBUILD_PATH=$(realpath .) -# Step 11: Set environment variables -echo "Configuring environment variables for TVM..." 
-echo "export TVM_HOME=$(pwd)/3rdparty/tvm" >> ~/.bashrc -echo "export PYTHONPATH=\$TVM_HOME/python:$(pwd):\$PYTHONPATH" >> ~/.bashrc -echo "export CUDA_DEVICE_ORDER=PCI_BUS_ID" >> ~/.bashrc +cd ../.. -# Step 12: Source .bashrc to apply changes -echo "Applying environment changes by sourcing .bashrc..." -source ~/.bashrc +echo "Building TileLang with CMake..." +cd tilelang +mkdir build +cd build + +cmake .. -DTVM_PREBUILD_PATH=$TVM_PREBUILD_PATH if [ $? -ne 0 ]; then - echo "Error: Failed to source .bashrc." + echo "Error: CMake configuration failed." exit 1 +fi + +make -j +if [ $? -ne 0 ]; then + echo "Error: TileLang build failed." + exit 1 +else + echo "TileLang build completed successfully." +fi + +echo "TileLang build completed successfully." + +cd ../../.. + +# Set environment variables +TVM_HOME_ENV="export TVM_HOME=$(pwd)/3rdparty/tvm" +TVM_EXPORT_ENV="export TVM_IMPORT_PYTHON_PATH=/root/BitBLAS/3rdparty/tvm/python" +TILELANG_HOME_ENV="export TILELANG_HOME=$(pwd)/3rdparty/tilelang" +BITBLAS_PYPATH_ENV="export PYTHONPATH=\$TVM_HOME/python:\$TILELANG_HOME:$(pwd):\$PYTHONPATH" +CUDA_DEVICE_ORDER_ENV="export CUDA_DEVICE_ORDER=PCI_BUS_ID" + +# Inject break line if the last line of the file is not empty +if [ -s ~/.bashrc ]; then + if [ "$(tail -c 1 ~/.bashrc)" != "" ]; then + echo "" >> ~/.bashrc + fi +fi + +# Check and add the first line if not already present +if ! grep -qxF "$TVM_HOME_ENV" ~/.bashrc; then + echo "$TVM_HOME_ENV" >> ~/.bashrc + echo "Added TVM_HOME to ~/.bashrc" else - echo "Environment configured successfully." + echo "TVM_HOME is already set in ~/.bashrc" fi -echo "Installation script completed successfully." +# Check and add the second line if not already present +if ! grep -qxF "$BITBLAS_PYPATH_ENV" ~/.bashrc; then + echo "$BITBLAS_PYPATH_ENV" >> ~/.bashrc + echo "Added PYTHONPATH to ~/.bashrc" +else + echo "PYTHONPATH is already set in ~/.bashrc" +fi + +# Check and add the third line if not already present +if ! grep -qxF "$CUDA_DEVICE_ORDER_ENV" ~/.bashrc; then + echo "$CUDA_DEVICE_ORDER_ENV" >> ~/.bashrc + echo "Added CUDA_DEVICE_ORDER to ~/.bashrc" +else + echo "CUDA_DEVICE_ORDER is already set in ~/.bashrc" +fi + +# Reload ~/.bashrc to apply the changes +source ~/.bashrc diff --git a/install_amd.sh b/install_amd.sh index f64e442a0..dec3dedcf 100755 --- a/install_amd.sh +++ b/install_amd.sh @@ -62,11 +62,46 @@ echo "set(USE_LLVM llvm-config-16)" >> config.cmake && echo "set(USE_ROCM /opt/r cmake .. && make -j && cd ../../.. +TVM_PREBUILD_PATH=$(realpath .) + +cd ../.. + +echo "Building TileLang with CMake..." +cd tilelang +mkdir build +cd build + +cmake .. -DTVM_PREBUILD_PATH=$TVM_PREBUILD_PATH +if [ $? -ne 0 ]; then + echo "Error: CMake configuration failed." + exit 1 +fi + +make -j +if [ $? -ne 0 ]; then + echo "Error: TileLang build failed." + exit 1 +else + echo "TileLang build completed successfully." +fi + +echo "TileLang build completed successfully." + +cd ../../.. 
+ # Define the lines to be added TVM_HOME_ENV="export TVM_HOME=$(pwd)/3rdparty/tvm" -BITBLAS_PYPATH_ENV="export PYTHONPATH=\$TVM_HOME/python:$(pwd):\$PYTHONPATH" +TILELANG_HOME_ENV="export TILELANG_HOME=$(pwd)/3rdparty/tilelang" +BITBLAS_PYPATH_ENV="export PYTHONPATH=\$TVM_HOME/python:\$TILELANG_HOME:$(pwd):\$PYTHONPATH" CUDA_DEVICE_ORDER_ENV="export CUDA_DEVICE_ORDER=PCI_BUS_ID" +# Inject break line if the last line of the file is not empty +if [ -s ~/.bashrc ]; then + if [ "$(tail -c 1 ~/.bashrc)" != "" ]; then + echo "" >> ~/.bashrc + fi +fi + # Check and add the first line if not already present if ! grep -qxF "$TVM_HOME_ENV" ~/.bashrc; then echo "$TVM_HOME_ENV" >> ~/.bashrc diff --git a/setup.py b/setup.py index 2f8bb4ce4..5244cc187 100644 --- a/setup.py +++ b/setup.py @@ -191,6 +191,27 @@ def build_tvm(llvm_config_path): os.chdir("../../..") +def build_tilelang(TVM_PREBUILD_PATH: str = "./3rdparty/tvm/build"): + """Builds TILELANG.""" + abs_tvm_prebuilt_path = os.path.abspath(TVM_PREBUILD_PATH) + print(f"Using TVM prebuilt path: {abs_tvm_prebuilt_path}") + + os.chdir("3rdparty/tilelang") + if not os.path.exists("build"): + os.makedirs("build") + os.chdir("build") + # Run CMake and make + try: + subprocess.check_call(["cmake", "..", f"-DTVM_PREBUILD_PATH={abs_tvm_prebuilt_path}"]) + num_jobs = multiprocessing.cpu_count() + subprocess.check_call(["make", f"-j{num_jobs}"]) + except subprocess.CalledProcessError as error: + raise RuntimeError("Failed to build TILELANG") from error + finally: + # Go back to the original directory + os.chdir("../../..") + + def setup_llvm_for_tvm(): """Downloads and extracts LLVM, then configures TVM to use it.""" # Assume the download_and_extract_llvm function and its dependencies are defined elsewhere in this script @@ -209,6 +230,8 @@ def run(self): _, llvm_path = setup_llvm_for_tvm() # Build TVM build_tvm(llvm_path) + # Build TILELANG + build_tilelang() # Continue with the standard installation process install.run(self) @@ -224,6 +247,8 @@ def run(self): _, llvm_path = setup_llvm_for_tvm() # Build TVM build_tvm(llvm_path) + # Build TILELANG + build_tilelang() # Copy the built TVM to the package directory TVM_PREBUILD_ITEMS = [ @@ -240,7 +265,6 @@ def run(self): "3rdparty/tvm/mypy.ini", "3rdparty/tvm/pyproject.toml", "3rdparty/tvm/version.py", - "3rdparty/tvm/src/tl/tl_templates", ] for item in TVM_PREBUILD_ITEMS: source_dir = os.path.join(ROOT_DIR, item) @@ -254,6 +278,26 @@ def run(self): os.makedirs(target_dir) shutil.copy2(source_dir, target_dir) + # Copy the built TILELANG to the package directory + TILELANG_PREBUILD_ITEMS = [ + "3rdparty/tilelang/build/libtilelang_module.so", + "3rdparty/tilelang/build/libtilelang.so", + "3rdparty/tilelang/tilelang", + "3rdparty/tilelang/src/tl_templates", + "3rdparty/tilelang/VERSION", + ] + for item in TILELANG_PREBUILD_ITEMS: + source_dir = os.path.join(ROOT_DIR, item) + target_dir = os.path.join(self.build_lib, PACKAGE_NAME, item) + if os.path.isdir(source_dir): + self.mkpath(target_dir) + distutils.dir_util.copy_tree(source_dir, target_dir) + else: + target_dir = os.path.dirname(target_dir) + if not os.path.exists(target_dir): + os.makedirs(target_dir) + shutil.copy2(source_dir, target_dir) + # Copy CUTLASS to the package directory CUTLASS_PREBUILD_ITEMS = [ "3rdparty/cutlass", diff --git a/testing/python/operators/test_general_matmul_ops_backend.py b/testing/python/operators/test_general_matmul_ops_backend.py index d1a2253f3..8d80f7d87 100644 --- a/testing/python/operators/test_general_matmul_ops_backend.py 
+++ b/testing/python/operators/test_general_matmul_ops_backend.py @@ -34,6 +34,10 @@ def matmul_codegen_default(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, la zeros_mode=zeros_mode, ) matmul = Matmul(config=matmul_config, enable_tuning=False) + func = matmul.prim_func + import tilelang + rt_mod, params = tilelang.lower(func) + print(rt_mod) assert get_codegen_result(matmul) diff --git a/testing/python/operators/test_general_matmul_tilelang_impl.py b/testing/python/operators/test_general_matmul_tilelang_impl.py index b05d10390..3f258bab8 100644 --- a/testing/python/operators/test_general_matmul_tilelang_impl.py +++ b/testing/python/operators/test_general_matmul_tilelang_impl.py @@ -2,8 +2,8 @@ # Licensed under the MIT License. from bitblas import tvm as tvm +from bitblas import tilelang as tilelang import bitblas.testing -from tvm import tl from bitblas.ops.general_matmul.tilelang.dense.matmul_tile import ( matmul_blocked, matmul_macro_tensorcore, @@ -47,7 +47,7 @@ def assert_matmul_blocked_correctness(M, enable_rasterization=enable_rasterization, ) - mod, params = tl.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -57,7 +57,7 @@ def assert_matmul_blocked_correctness(M, B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype)) - mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) @@ -105,7 +105,7 @@ def assert_matmul_macro_tensorcore_correctness( num_stages=num_stages, enable_rasterization=enable_rasterization, ) - mod, params = tl.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code represents generated cuda source @@ -115,7 +115,7 @@ def assert_matmul_macro_tensorcore_correctness( B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) - 0.5 C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) - mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) @@ -164,7 +164,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_correctness( enable_rasterization=enable_rasterization, ) - mod, params = tl.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -185,7 +185,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_correctness( LB = ladder_permutate(B.cpu()).cuda() - mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, LB, C) diff --git a/testing/python/operators/test_general_matmul_tilelang_kernel.py b/testing/python/operators/test_general_matmul_tilelang_kernel.py index b789bd4e1..dc1f1c424 100644 --- a/testing/python/operators/test_general_matmul_tilelang_kernel.py +++ b/testing/python/operators/test_general_matmul_tilelang_kernel.py @@ -3,7 +3,7 @@ from bitblas import tvm as tvm import bitblas.testing -from tvm import tl +from bitblas import tilelang as tilelang from bitblas.ops.general_matmul.tilelang.dense.matmul_tile import ( MatmulTileLibraryScheduler,) @@ -51,7 +51,7 @@ def assert_matmul_blocked_with_default_correctness( accum_dtype=accum_dtype, ).with_default_config() - mod, params = tl.lower(matmul) + mod, params = 
tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -61,7 +61,7 @@ def assert_matmul_blocked_with_default_correctness( B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) - mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) @@ -109,7 +109,7 @@ def assert_matmul_blocked_apply_config_correctness( enable_rasterization=enable_rasterization, ) - mod, params = tl.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -119,7 +119,7 @@ def assert_matmul_blocked_apply_config_correctness( B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) - mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) @@ -155,7 +155,7 @@ def assert_matmul_fine_grained_with_default_correctness( accum_dtype=accum_dtype, ).with_default_config() - mod, params = tl.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -163,7 +163,7 @@ def assert_matmul_fine_grained_with_default_correctness( B = (torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) if trans_B else torch.rand( K, N, device="cuda", dtype=getattr(torch, in_dtype))) - 0.5 C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype)) - mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) latency = mod.do_bench(mod.func, warmup=25) @@ -217,7 +217,7 @@ def assert_matmul_fine_grained_apply_config_correctness( enable_rasterization=enable_rasterization, ) - mod, params = tl.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -227,7 +227,7 @@ def assert_matmul_fine_grained_apply_config_correctness( B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) - 0.5 C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) - mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) @@ -263,7 +263,7 @@ def assert_matmul_weight_propagation_with_default_correctness( accum_dtype=accum_dtype, ).with_default_config() - mod, params = tl.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -284,7 +284,7 @@ def assert_matmul_weight_propagation_with_default_correctness( LB = ladder_permutate(B.cpu()).cuda() - mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, LB, C) @@ -337,7 +337,7 @@ def assert_matmul_weight_propagation_apply_config_correctness( enable_rasterization=enable_rasterization, ) - mod, params = tl.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -358,7 +358,7 @@ def assert_matmul_weight_propagation_apply_config_correctness( LB = 
ladder_permutate(B.cpu()).cuda() - mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, LB, C) @@ -394,7 +394,7 @@ def assert_matmul_int4_fine_grained_with_default_correctness( accum_dtype=accum_dtype, ).with_default_config() - mod, params = tl.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -405,7 +405,7 @@ def assert_matmul_int4_fine_grained_with_default_correctness( compressed_A = (A[:, ::2] & 0x0F) + ((A[:, 1::2] & 0x0F) << 4) compressed_B = (B[:, ::2] & 0x0F) + ((B[:, 1::2] & 0x0F) << 4) - mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) latency = mod.do_bench(mod.func, warmup=25, profiler="tvm") print(latency) @@ -459,7 +459,7 @@ def assert_matmul_int4_fine_grained_apply_config_correctness( enable_rasterization=enable_rasterization, ) - mod, params = tl.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -470,7 +470,7 @@ def assert_matmul_int4_fine_grained_apply_config_correctness( compressed_A = (A[:, ::2] & 0x0F) + ((A[:, 1::2] & 0x0F) << 4) compressed_B = (B[:, ::2] & 0x0F) + ((B[:, 1::2] & 0x0F) << 4) - mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) latency = mod.do_bench(mod.func, warmup=25, profiler="tvm") print(latency) @@ -509,7 +509,7 @@ def assert_matmul_int4_weight_propagation_with_default_correctness( accum_dtype=accum_dtype, ).with_default_config() print(matmul) - mod, params = tl.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -534,7 +534,7 @@ def assert_matmul_int4_weight_propagation_with_default_correctness( LB = ladder_permutate(compressed_B.cpu()).cuda() - mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(compressed_A, LB, C) @@ -588,7 +588,7 @@ def assert_matmul_int4_weight_propagation_apply_config__correctness( ) print(matmul) - mod, params = tl.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -613,7 +613,7 @@ def assert_matmul_int4_weight_propagation_apply_config__correctness( LB = ladder_permutate(compressed_B.cpu()).cuda() - mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(compressed_A, LB, C) @@ -666,7 +666,7 @@ def assert_matmul_fine_grained_dequant_int4_with_default_correctness( zeros_mode=zeros_mode, ).with_default_config() - mod, params = tl.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -692,7 +692,7 @@ def assert_matmul_fine_grained_dequant_int4_with_default_correctness( compressed_B = (B[:, ::4] & 0x03) + ((B[:, 1::4] & 0x03) << 2) + ((B[:, 2::4] & 0x03) << 4) + ( (B[:, 3::4] & 0x03) << 6) - mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) 
print(f"{compressed_B=}") if fast_decoding: lop3_compressed_B = lop3_permutate(compressed_B.cpu()).cuda() @@ -765,7 +765,7 @@ def assert_matmul_fine_grained_dequant_int4_apply_config_correctness( enable_rasterization=enable_rasterization, ) - mod, params = tl.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -791,7 +791,7 @@ def assert_matmul_fine_grained_dequant_int4_apply_config_correctness( compressed_B = (B[:, ::4] & 0x03) + ((B[:, 1::4] & 0x03) << 2) + ((B[:, 2::4] & 0x03) << 4) + ( (B[:, 3::4] & 0x03) << 6) - mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) print(f"{compressed_B=}") if fast_decoding: lop3_compressed_B = lop3_permutate(compressed_B.cpu()).cuda() @@ -849,7 +849,7 @@ def assert_matmul_weight_transform_dequant_int4_with_default_correctness( zeros_mode=zeros_mode, ).with_default_config() print(matmul) - mod, params = tl.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -884,7 +884,7 @@ def assert_matmul_weight_transform_dequant_int4_with_default_correctness( ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config) - mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) compressed_B_ladder = ladder_permutate(compressed_B.cpu()).cuda() ladder_shape = compressed_B_ladder.shape int2_shape = (ladder_shape[:-1] + (ladder_shape[-1] // 2,)) @@ -970,7 +970,7 @@ def assert_matmul_weight_transform_dequant_int4_apply_config_correctness( ) print(matmul) - mod, params = tl.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -1005,7 +1005,7 @@ def assert_matmul_weight_transform_dequant_int4_apply_config_correctness( ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config) - mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) compressed_B_ladder = ladder_permutate(compressed_B.cpu()).cuda() ladder_shape = compressed_B_ladder.shape int2_shape = (ladder_shape[:-1] + (ladder_shape[-1] // 2,)) @@ -1078,7 +1078,7 @@ def assert_matmul_blocked_dequant_with_default_correctness( zeros_mode=zeros_mode, ).with_default_config() print(matmul) - mod, params = tl.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -1135,7 +1135,7 @@ def assert_matmul_blocked_dequant_with_default_correctness( permuted_inputs.append(inputs[2]) - mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(*permuted_inputs) @@ -1213,7 +1213,7 @@ def assert_matmul_fine_grained_dequant_with_default_correctness( zeros_mode=zeros_mode, ).with_default_config() - mod, params = tl.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -1268,7 +1268,7 @@ def assert_matmul_fine_grained_dequant_with_default_correctness( permuted_inputs.append(inputs[2]) - mod = 
tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(*permuted_inputs) @@ -1344,7 +1344,7 @@ def assert_matmul_weight_transform_dequant_with_default_correctness( ).with_default_config() if verbose: print(matmul) - mod, params = tl.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -1415,7 +1415,7 @@ def assert_matmul_weight_transform_dequant_with_default_correctness( permuted_inputs.append(inputs[2]) - mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(*permuted_inputs) diff --git a/testing/python/tilelang/test_simplifier.py b/testing/python/tilelang/test_simplifier.py index 96536670a..18613edc9 100644 --- a/testing/python/tilelang/test_simplifier.py +++ b/testing/python/tilelang/test_simplifier.py @@ -1,6 +1,6 @@ import tvm -from tvm import tl -import tvm.tl.language as T +from bitblas import tilelang as tilelang +import tilelang.language as T def modify( @@ -36,7 +36,7 @@ def main( def test_modify(with_B=False, with_bias=False): tester = modify(with_B=with_B, with_bias=with_bias) mod = tvm.IRModule({tester.attrs["global_symbol"]: tester}) - mod2 = tl.transform.Simplify()(mod) + mod2 = tilelang.transform.Simplify()(mod) assert mod != mod2 @@ -71,11 +71,11 @@ def main( def test_matmul(): func = matmul(1024, 1024, 1024, 128, 128, 32) mod = tvm.IRModule({func.attrs["global_symbol"]: func}) - mod = tl.transform.Simplify()(mod) + mod = tilelang.transform.Simplify()(mod) - rt_mod, params = tl.lower(mod.functions_items()[0][1], runtime_only=False) + rt_mod, params = tilelang.lower(mod.functions_items()[0][1], runtime_only=False) # TODO Profiler only support TensorType, not dynamic variable - profiler = tl.Profiler(rt_mod, params, result_idx=[2]) + profiler = tilelang.Profiler(rt_mod, params, result_idx=[2]) import torch a = torch.randn(1024, 1024, dtype=torch.float16).cuda().half() diff --git a/testing/python/tilelang/test_tilelang_amd_gemm.py b/testing/python/tilelang/test_tilelang_amd_gemm.py index f281f8eb0..20abd415c 100644 --- a/testing/python/tilelang/test_tilelang_amd_gemm.py +++ b/testing/python/tilelang/test_tilelang_amd_gemm.py @@ -3,7 +3,7 @@ from bitblas import tvm as tvm import bitblas.testing -from tvm import tl +from bitblas import tilelang as tilelang def matmul( @@ -27,7 +27,7 @@ def matmul( A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) vec_size = 4 * k_pack - import tvm.tl.language as T + import tilelang.language as T @T.prim_func def main(A: T.Buffer(A_shape, in_dtype), B: T.Buffer(B_shape, in_dtype), C: T.Buffer( @@ -84,8 +84,8 @@ def run_gemm( num_threads, k_pack=k_pack, ) - mod, params = tl.lower(program) - mod = tl.Profiler(mod, params, [2], tl.TensorSupplyType.Integer) + mod, params = tilelang.lower(program) + mod = tilelang.Profiler(mod, params, [2], tilelang.TensorSupplyType.Integer) def ref_program(A, B): import torch diff --git a/testing/python/tilelang/test_tilelang_dequantize_gemm.py b/testing/python/tilelang/test_tilelang_dequantize_gemm.py index e3d47b309..2e4873f89 100644 --- a/testing/python/tilelang/test_tilelang_dequantize_gemm.py +++ b/testing/python/tilelang/test_tilelang_dequantize_gemm.py @@ -6,8 +6,8 @@ import bitblas.testing from bitblas import tvm as tvm from tvm import DataType -from tvm 
import tl as TL -import tvm.tl.language as T +from bitblas import tilelang as tilelang +import tilelang.language as T from bitblas.quantization import _tir_packed_to_unsigned_convert from bitblas.tl.utils import make_mma_swizzle_layout as make_swizzle_layout from bitblas.tl.mma_macro_generator import ( @@ -45,7 +45,7 @@ def matmul( local_size = MAX_TRANSACTION_SIZE_IN_BITS // DataType(in_dtype).bits local_size_compressed = local_size // num_elems_per_byte - import tvm.tl.language as T + import tilelang.language as T @T.prim_func def main( @@ -123,8 +123,8 @@ def run_gemm( num_threads, ) - mod, params = TL.lower(program) - mod = TL.Profiler(mod, params, [2], TL.TensorSupplyType.Integer) + mod, params = tilelang.lower(program) + mod = tilelang.Profiler(mod, params, [2], tilelang.TensorSupplyType.Integer) out = mod.run_once() assert out is not None @@ -367,7 +367,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct matmul = tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( M, N, K, in_dtype, out_dtype, accum_dtype, transform_b) - mod, params = TL.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -406,7 +406,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct QLB = ladder_permutate(qB.cpu()).cuda() QLB = lop3_permutate(QLB.cpu()).cuda() - mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, QLB, C) diff --git a/testing/python/tilelang/test_tilelang_dyanmic_symbolic.py b/testing/python/tilelang/test_tilelang_dyanmic_symbolic.py index ae63cce9e..ae027cbf4 100644 --- a/testing/python/tilelang/test_tilelang_dyanmic_symbolic.py +++ b/testing/python/tilelang/test_tilelang_dyanmic_symbolic.py @@ -6,8 +6,8 @@ from bitblas import tvm as tvm import bitblas.testing from tvm import DataType -from tvm import tl as TL -import tvm.tl.language as T +from bitblas import tilelang as tilelang +import tilelang.language as T from bitblas.tl.utils import get_swizzle_layout from bitblas.tl.mma_macro_generator import (TensorCoreIntrinEmitter) @@ -178,7 +178,7 @@ def main( def assert_tl_matmul_macro_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): matmul = tl_matmul_macro(N, K, in_dtype, out_dtype, accum_dtype) - mod, params = TL.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -188,7 +188,7 @@ def assert_tl_matmul_macro_correctness(M, N, K, in_dtype, out_dtype, accum_dtype B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) - mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) @@ -217,7 +217,7 @@ def tl_matmul_block( A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) - import tvm.tl.language as T + import tilelang.language as T @T.prim_func def main(A: T.Buffer(A_shape, in_dtype), B: T.Buffer(B_shape, in_dtype), C: T.Buffer( @@ -271,13 +271,13 @@ def assert_tl_matmul_block_correctness( num_stages, num_threads, ) - mod, params = TL.lower(program) + mod, params = tilelang.lower(program) A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) B = torch.rand(N, K, 
device="cuda", dtype=getattr(torch, in_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype)) - mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) def ref_program(A, B): @@ -318,7 +318,7 @@ def tl_matmul_block_all_dynamic( A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) - import tvm.tl.language as T + import tilelang.language as T @T.prim_func def main(A: T.Buffer(A_shape, in_dtype), B: T.Buffer(B_shape, in_dtype), C: T.Buffer( @@ -370,7 +370,7 @@ def assert_tl_matmul_block_all_dynamic_correctness( num_stages, num_threads, ) - mod, params = TL.lower(program) + mod, params = tilelang.lower(program) if trans_A: A = torch.rand(K, M, device="cuda", dtype=getattr(torch, in_dtype)) else: @@ -381,7 +381,7 @@ def assert_tl_matmul_block_all_dynamic_correctness( B = torch.rand(K, N, device="cuda", dtype=getattr(torch, in_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype)) - mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) def ref_program(A, B): diff --git a/testing/python/tilelang/test_tilelang_flash_atten.py b/testing/python/tilelang/test_tilelang_flash_atten.py index 3b9e33440..2c1c834ee 100644 --- a/testing/python/tilelang/test_tilelang_flash_atten.py +++ b/testing/python/tilelang/test_tilelang_flash_atten.py @@ -3,9 +3,9 @@ import bitblas import bitblas.testing from bitblas import tvm as tvm -from tvm import tl -import tvm.tl.language as T -from tvm.tl.autotuner import * +from bitblas import tilelang as tilelang +import tilelang.language as T +from tilelang.autotuner import * from functools import partial import itertools import torch @@ -66,8 +66,8 @@ def flashattn_tilelang(batch, heads, seq_len, dim, trans_K, dtypeQKV, dtypeAccu, num_stages=num_stages, is_causal=is_causal, ) - mod, params = tl.lower(tl_prim_func) - mod = tl.Profiler(mod, params, [3], tl.TensorSupplyType.Normal) + mod, params = tilelang.lower(tl_prim_func) + mod = tilelang.Profiler(mod, params, [3], tilelang.TensorSupplyType.Normal) from flash_attn.flash_attn_interface import flash_attn_func # TODO Now hack to internal function get the same input, may need to modify 3rdparty:tvm.tl.utils ins = mod._get_inputs() @@ -123,7 +123,7 @@ def main( scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - T.annotate_layout({Q_shared: tl.layout.make_swizzled_layout(Q_shared)}) + T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) @@ -177,8 +177,8 @@ def main( return main - mod, params = tl.lower(kernel()) - mod = tl.Profiler(mod, params, [3], tl.TensorSupplyType.Normal) + mod, params = tilelang.lower(kernel()) + mod = tilelang.Profiler(mod, params, [3], tilelang.TensorSupplyType.Normal) mod.assert_allclose(partial(ref_program, causal=is_causal), rtol=0.01, atol=0.01) @@ -206,7 +206,7 @@ def flashattn_autotune(batch, heads, seq_len, dim, is_causal): ) @jit( out_idx=[3], - supply_type=tl.TensorSupplyType.Normal, + supply_type=tilelang.TensorSupplyType.Normal, ref_prog=partial(ref_program, causal=is_causal), rtol=0.01, atol=0.01, @@ -239,7 +239,7 @@ def main( scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = 
T.alloc_fragment([block_M], accum_dtype) - T.annotate_layout({Q_shared: tl.layout.make_swizzled_layout(Q_shared)}) + T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) @@ -344,7 +344,7 @@ def main( scores_sum = T.alloc_fragment([block_M], accum_dtype) logsum = T.alloc_fragment([block_M], accum_dtype) - T.annotate_layout({Q_shared: tl.layout.make_swizzled_layout(Q_shared)}) + T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)}) T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) T.fill(acc_o, 0) T.fill(logsum, 0) @@ -398,8 +398,8 @@ def main( return main - mod, params = tl.lower(kernel()) - mod = tl.Profiler(mod, params, [3], tl.TensorSupplyType.Normal) + mod, params = tilelang.lower(kernel()) + mod = tilelang.Profiler(mod, params, [3], tilelang.TensorSupplyType.Normal) mod.assert_allclose(partial(ref_program, causal=is_causal), rtol=0.1, atol=0.1) diff --git a/testing/python/tilelang/test_tilelang_gemm.py b/testing/python/tilelang/test_tilelang_gemm.py index a4722eb99..bd26fcc1f 100644 --- a/testing/python/tilelang/test_tilelang_gemm.py +++ b/testing/python/tilelang/test_tilelang_gemm.py @@ -3,7 +3,7 @@ from bitblas import tvm as tvm import bitblas.testing -from tvm import tl +from bitblas import tilelang as tilelang def matmul( @@ -26,7 +26,7 @@ def matmul( A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) - import tvm.tl.language as T + import tilelang.language as T @T.prim_func def main(A: T.Buffer(A_shape, in_dtype), B: T.Buffer(B_shape, in_dtype), C: T.Buffer( @@ -81,8 +81,8 @@ def run_gemm( num_stages, num_threads, ) - mod, params = tl.lower(program) - mod = tl.Profiler(mod, params, [2], tl.TensorSupplyType.Integer) + mod, params = tilelang.lower(program) + mod = tilelang.Profiler(mod, params, [2], tilelang.TensorSupplyType.Integer) def ref_program(A, B): import torch diff --git a/testing/python/tilelang/test_tilelang_gemm_s4_mma.py b/testing/python/tilelang/test_tilelang_gemm_s4_mma.py index 42e449056..b32fd7833 100644 --- a/testing/python/tilelang/test_tilelang_gemm_s4_mma.py +++ b/testing/python/tilelang/test_tilelang_gemm_s4_mma.py @@ -6,8 +6,8 @@ from bitblas import tvm as tvm import bitblas.testing from tvm import DataType -from tvm import tl as TL -import tvm.tl.language as T +from bitblas import tilelang as tilelang +import tilelang.language as T from bitblas.tl.utils import ( make_mma_swizzle_layout as make_swizzle_layout,) @@ -173,7 +173,7 @@ def main( def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) - mod, params = TL.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -184,7 +184,7 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): compressed_A = (A[:, ::2] & 0x0F) + ((A[:, 1::2] & 0x0F) << 4) compressed_B = (B[:, ::2] & 0x0F) + ((B[:, 1::2] & 0x0F) << 4) - mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(compressed_A, compressed_B, C) print(C) latency = mod.do_bench(mod.func, warmup=25, profiler="tvm") @@ -368,7 +368,7 @@ def main( def assert_tl_matmul_weight_only_transform_correctness(M, N, 
K, in_dtype, out_dtype, accum_dtype): matmul = tl_matmul_weight_only_transform(M, N, K, in_dtype, out_dtype, accum_dtype) - mod, params = TL.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -391,7 +391,7 @@ def assert_tl_matmul_weight_only_transform_correctness(M, N, K, in_dtype, out_dt ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config) - mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) LB = ladder_permutate(compressed_B.cpu()).cuda() mod(compressed_A, LB, C) diff --git a/testing/python/tilelang/test_tilelang_gemm_simt.py b/testing/python/tilelang/test_tilelang_gemm_simt.py index 67e2f70e2..33e5abae6 100644 --- a/testing/python/tilelang/test_tilelang_gemm_simt.py +++ b/testing/python/tilelang/test_tilelang_gemm_simt.py @@ -6,8 +6,8 @@ from bitblas import tvm as tvm import bitblas.testing from tvm import DataType -from tvm import tl as TL -import tvm.tl.language as T +from bitblas import tilelang as tilelang +import tilelang.language as T from bitblas.tl.utils import get_swizzle_layout from bitblas.base import simplify_prim_func @@ -142,7 +142,7 @@ def main( def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): matmul = tl_matmul_simt(M, N, K, in_dtype, out_dtype, accum_dtype) - mod, params = TL.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() print(src_code) # src_code is the generated cuda source @@ -157,7 +157,7 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) - mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) diff --git a/testing/python/tilelang/test_tilelang_mfma_macro_gemm.py b/testing/python/tilelang/test_tilelang_mfma_macro_gemm.py index 2f44aea85..b1f16c207 100644 --- a/testing/python/tilelang/test_tilelang_mfma_macro_gemm.py +++ b/testing/python/tilelang/test_tilelang_mfma_macro_gemm.py @@ -5,8 +5,8 @@ import torch.backends from bitblas import tvm as tvm import bitblas.testing -from tvm import tl as TL -import tvm.tl.language as T +from bitblas import tilelang as tilelang +import tilelang.language as T from bitblas.tl.utils import make_mfma_swizzle_layout as make_swizzle_layout from bitblas.tl.mfma_macro_generator import ( MatrixCoreIntrinEmitter,) @@ -172,7 +172,7 @@ def main( def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype="float32"): matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) - mod, params = TL.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -186,7 +186,7 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype="floa C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype)) - mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) diff --git a/testing/python/tilelang/test_tilelang_mma_macro_gemm.py b/testing/python/tilelang/test_tilelang_mma_macro_gemm.py index 660aaad89..dbdfd1034 100644 --- a/testing/python/tilelang/test_tilelang_mma_macro_gemm.py +++ 
b/testing/python/tilelang/test_tilelang_mma_macro_gemm.py @@ -6,8 +6,8 @@ from bitblas import tvm as tvm import bitblas.testing from tvm import DataType -from tvm import tl as TL -import tvm.tl.language as T +from bitblas import tilelang as tilelang +import tilelang.language as T from bitblas.tl.utils import get_swizzle_layout from bitblas.tl.mma_macro_generator import ( TensorCoreIntrinEmitter, @@ -186,7 +186,7 @@ def main( def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) - mod, params = TL.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -200,7 +200,7 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) - mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) @@ -387,7 +387,7 @@ def main( def assert_tl_matmul_with_block_reduce_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): matmul = tl_matmul_with_block_reduce(M, N, K, in_dtype, out_dtype, accum_dtype) - mod, params = TL.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -397,7 +397,7 @@ def assert_tl_matmul_with_block_reduce_correctness(M, N, K, in_dtype, out_dtype, B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) - mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) @@ -564,7 +564,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_correctness(M, N, K, in_d matmul = tl_matmul_with_ladder_weight_only_transform(M, N, K, in_dtype, out_dtype, accum_dtype, transform_b) - mod, params = TL.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -583,7 +583,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_correctness(M, N, K, in_d ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config) - mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) LB = ladder_permutate(B.cpu()).cuda() @@ -824,7 +824,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct matmul = tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( M, N, K, in_dtype, out_dtype, accum_dtype, transform_b) - mod, params = TL.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -863,7 +863,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct QLB = ladder_permutate(qB.cpu()).cuda() QLB = lop3_permutate(QLB.cpu()).cuda() - mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, QLB, C) @@ -1035,7 +1035,7 @@ def assert_tl_matmul_with_ladder_input_weight_transform_correctness(M, N, K, in_ matmul = tl_matmul_with_ladder_input_weight_transform(M, N, K, in_dtype, out_dtype, accum_dtype, transform_a, transform_b) - mod, params = 
TL.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -1068,7 +1068,7 @@ def assert_tl_matmul_with_ladder_input_weight_transform_correctness(M, N, K, in_ ladder_permutate_b = bitblas.ops.LadderPermutate(ladder_permutate_config_B) - mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) LA = ladder_permutate_a(A.cpu()).cuda() LB = ladder_permutate_b(B.cpu()).cuda()
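
The patch above consistently swaps the old `tvm.tl` entry points for the vendored `tilelang` package: `from tvm import tl` / `import tvm.tl.language as T` become `from bitblas import tilelang as tilelang` / `import tilelang.language as T`, and `tl.lower` / `tl.Profiler` / `tl.TensorSupplyType` become their `tilelang.*` counterparts. The snippet below is a minimal usage sketch of that new pattern, not code from this patch: the entry points (`tilelang.lower`, `tilelang.Profiler`, `tilelang.TensorSupplyType`, `tilelang.language`) are taken from the diff, while the GEMM kernel body, block sizes, and the torch-based check are illustrative assumptions adapted from the standard tile-language GEMM example, and a CUDA device plus the bundled 3rdparty/tilelang build on PYTHONPATH (as set up by install.sh and bitblas/__init__.py above) are assumed.

# Illustrative sketch only -- not part of this patch.
import torch

from bitblas import tvm as tvm              # noqa: F401
from bitblas import tilelang as tilelang    # replaces: from tvm import tl
import tilelang.language as T               # replaces: import tvm.tl.language as T

M = N = K = 1024
block_M = block_N = 128
block_K = 32
dtype, accum_dtype = "float16", "float32"


@T.prim_func
def gemm(A: T.Buffer((M, K), dtype), B: T.Buffer((K, N), dtype),
         C: T.Buffer((M, N), accum_dtype)):
    # Blocked GEMM written in the tile language; pipelining and layout details are
    # simplified relative to the schedulers touched in this diff.
    with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
        A_shared = T.alloc_shared((block_M, block_K), dtype)
        B_shared = T.alloc_shared((block_K, block_N), dtype)
        C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
        T.clear(C_local)
        for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
            T.copy(A[by * block_M, k * block_K], A_shared)
            T.copy(B[k * block_K, bx * block_N], B_shared)
            T.gemm(A_shared, B_shared, C_local)
        T.copy(C_local, C[by * block_M, bx * block_N])


# Lowering and profiling now go through tilelang instead of tvm.tl:
mod, params = tilelang.lower(gemm)                 # was: tl.lower(gemm) / TL.lower(gemm)
src_code = mod.imported_modules[0].get_source()    # generated CUDA source, same as before

profiler = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer)
A = torch.rand(M, K, device="cuda", dtype=torch.float16)
B = torch.rand(K, N, device="cuda", dtype=torch.float16)
C = torch.zeros(M, N, device="cuda", dtype=torch.float32)
profiler(A, B, C)                                  # was: tl.Profiler(...)(A, B, C)
torch.testing.assert_close(C, A.float() @ B.float(), rtol=1e-2, atol=1e-2)

The call shapes mirror the updated tests (e.g. test_tilelang_gemm.py and test_general_matmul_tilelang_kernel.py), which construct a Profiler from the lowered module and invoke it directly on torch tensors.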