Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added .DS_Store
Binary file not shown.
30 changes: 28 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -239,12 +239,32 @@ elseif(BUILD_MPS)
add_compile_definitions(BUILD_MPS)
file(MAKE_DIRECTORY "build")
add_custom_command(OUTPUT "bitsandbytes/bitsandbytes.metallib"
COMMAND xcrun metal -c -o "build/bitsandbytes.air" ${METAL_FILES}
COMMAND xcrun metal -c -g -frecord-sources -gline-tables-only -o "build/bitsandbytes.air" ${METAL_FILES}
COMMAND xcrun metallib "build/bitsandbytes.air" -o "bitsandbytes/bitsandbytes.metallib"
DEPENDS "${METAL_FILES}"
COMMENT "Compiling Metal kernels"
VERBATIM)
add_custom_target(metallib DEPENDS "bitsandbytes/bitsandbytes.metallib")
if(NOT Torch_DIR)
find_package(Python3 COMPONENTS Interpreter)
if(Python3_EXECUTABLE)
execute_process(
COMMAND "${Python3_EXECUTABLE}" -c "import torch; import sys; sys.stdout.write(torch.utils.cmake_prefix_path)"
OUTPUT_VARIABLE TORCH_CMAKE_PREFIX_PATH
ERROR_VARIABLE TORCH_DETECT_ERROR
RESULT_VARIABLE TORCH_DETECT_RESULT
)
if(TORCH_DETECT_RESULT EQUAL 0 AND TORCH_CMAKE_PREFIX_PATH)
list(APPEND CMAKE_PREFIX_PATH "${TORCH_CMAKE_PREFIX_PATH}")
endif()
endif()
endif()
find_package(Torch REQUIRED)
if(TORCH_CXX_FLAGS)
string(APPEND CMAKE_CXX_FLAGS " ${TORCH_CXX_FLAGS}")
endif()
set(BNB_TORCH_INCLUDE_DIRS ${TORCH_INCLUDE_DIRS})
set(BNB_TORCH_LIBRARIES ${TORCH_LIBRARIES})
elseif(BUILD_XPU)
list(APPEND SRC_FILES ${XPU_FILES})
string(APPEND BNB_OUTPUT_NAME "_xpu")
Expand Down Expand Up @@ -351,7 +371,13 @@ if(BUILD_HIP)
endif()
if(BUILD_MPS)
add_dependencies(bitsandbytes metallib)
target_link_libraries(bitsandbytes objc "-framework Foundation" "-framework Metal" "-framework MetalPerformanceShaders" "-framework MetalPerformanceShadersGraph")
target_link_libraries(bitsandbytes PRIVATE objc "-framework Foundation" "-framework Metal" "-framework MetalPerformanceShaders" "-framework MetalPerformanceShadersGraph")
if(BNB_TORCH_INCLUDE_DIRS)
target_include_directories(bitsandbytes PRIVATE ${BNB_TORCH_INCLUDE_DIRS})
endif()
if(BNB_TORCH_LIBRARIES)
target_link_libraries(bitsandbytes PRIVATE ${BNB_TORCH_LIBRARIES})
endif()
endif()
if(BUILD_XPU)
set(SYCL_LINK_FLAGS "-fsycl;--offload-compress;-fsycl-targets=spir64_gen,spir64;-Xs;-device pvc,xe-lpg,ats-m150 -options ' -cl-intel-enable-auto-large-GRF-mode -cl-poison-unsupported-fp64-kernels -cl-intel-greater-than-4GB-buffer-required'")
Expand Down
30 changes: 30 additions & 0 deletions __init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""Dispatcher shimming the editable layout.

When this repository is used via ``pip install -e .`` the real Python
package lives under ``bitsandbytes/bitsandbytes``. Importing from the
workspace root (e.g. running scripts from ``.../ai/kernels``) would
otherwise resolve to this outer directory, yielding a namespace module
with no attributes. Import the inner package eagerly and mirror its
symbols so ``import bitsandbytes`` always behaves the same as the
installed wheel.
"""

from __future__ import annotations

import importlib
from types import ModuleType

_inner: ModuleType = importlib.import_module(".bitsandbytes", __name__)

Check failure on line 17 in __init__.py

View workflow job for this annotation

GitHub Actions / Lint

Ruff (F821)

__init__.py:17:9: F821 Undefined name `ModuleType`

# Copy dunder metadata expected by consumers.
for _name in ("__all__", "__doc__", "__file__", "__loader__", "__path__", "__spec__", "__version__"):
if hasattr(_inner, _name):
globals()[_name] = getattr(_inner, _name)

# Re-export public symbols while leaving dunders alone.
for _name, _value in vars(_inner).items():
if not _name.startswith("__"):
globals()[_name] = _value

del _inner, _name, _value, ModuleType, importlib

5 changes: 4 additions & 1 deletion bitsandbytes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import torch

from . import _ops, research, utils
from . import _ops, nn, research, utils
from .autograd._functions import (
MatmulLtState,
matmul,
Expand Down Expand Up @@ -38,6 +38,9 @@
if hasattr(torch, "xpu") and torch.xpu.is_available():
from .backends.xpu import ops as xpu_ops

if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
from .backends.mps import ops as mps_ops

if importlib.util.find_spec("habana_frameworks") and importlib.util.find_spec("habana_frameworks.torch"):
# In case not automatically imported
import habana_frameworks.torch
Expand Down
12 changes: 8 additions & 4 deletions bitsandbytes/backends/default/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,7 @@ def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int,
return out


@register_kernel("bitsandbytes::quantize_4bit", "default")
def _(
def _quantize_4bit_impl(
A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
torch._check_is_size(blocksize)
Expand Down Expand Up @@ -232,6 +231,13 @@ def _(
return packed, absmax.float()


@register_kernel("bitsandbytes::quantize_4bit", "default")
def _(
A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
return _quantize_4bit_impl(A, blocksize, quant_type, quant_storage)


def _dequantize_4bit_impl(
A: torch.Tensor,
absmax: torch.Tensor,
Expand All @@ -243,7 +249,6 @@ def _dequantize_4bit_impl(
# Enable non uint8 dtype
if A.dtype != torch.uint8:
A = A.view(torch.uint8)

A = A.reshape(-1)
# Map nf4 to [-1, 1]
out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device)
Expand Down Expand Up @@ -290,7 +295,6 @@ def _(
dtype in [torch.bfloat16, torch.float16, torch.float32],
lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
)

return _dequantize_4bit_impl(A, absmax, blocksize, quant_type, shape, dtype)


Expand Down
2 changes: 2 additions & 0 deletions bitsandbytes/backends/mps/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# MPS backend registrations are defined in ops.py

202 changes: 202 additions & 0 deletions bitsandbytes/backends/mps/ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
from collections.abc import Sequence
from typing import Optional

import ctypes as ct
from ctypes import _CFuncPtr
import torch

from ..._ops import register_kernel
from ...cextension import lib
from ..default.ops import _dequantize_4bit_impl, _quantize_4bit_impl
from ..utils import CODE
from .shim import MPSTensorShim#, configure_mps_blockwise_kernel


def _check_mps_device(tensor: torch.Tensor, name: str) -> None:
torch._check(
tensor.device.type == "mps",
lambda: f"{name} must live on an MPS device for the MPS backend, got {tensor.device.type}",
)


def _supports_dtype(dtype: torch.dtype) -> bool:
return dtype in (torch.float16, torch.float32, torch.bfloat16)


def _kernel_dtype(dtype: torch.dtype) -> torch.dtype:
if dtype == torch.bfloat16:
return torch.float32
return dtype


def _resolve_quant_fn(dtype: torch.dtype, quant_type: str) -> Optional[_CFuncPtr]:
try:
if dtype == torch.float16:
fn = getattr(
lib,
"cquantize_blockwise_fp16_fp4" if quant_type == "fp4" else "cquantize_blockwise_fp16_nf4",
)
# configure_mps_blockwise_kernel(fn)
return fn
if dtype == torch.float32:
fn = getattr(
lib,
"cquantize_blockwise_fp32_fp4" if quant_type == "fp4" else "cquantize_blockwise_fp32_nf4",
)
# configure_mps_blockwise_kernel(fn)
return fn
except AttributeError:
return None
return None


def _resolve_dequant_fn(dtype: torch.dtype, quant_type: str) -> Optional[_CFuncPtr]:
try:
if dtype == torch.float16:
fn = getattr(
lib,
"cdequantize_blockwise_fp16_fp4" if quant_type == "fp4" else "cdequantize_blockwise_fp16_nf4",
)
# configure_mps_blockwise_kernel(fn)
return fn
if dtype == torch.float32:
fn = getattr(
lib,
"cdequantize_blockwise_fp32_fp4" if quant_type == "fp4" else "cdequantize_blockwise_fp32_nf4",
)
# configure_mps_blockwise_kernel(fn)
return fn
except AttributeError:
return None
return None


def _quantize_4bit_native(
A: torch.Tensor,
blocksize: int,
quant_type: str,
quant_storage: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor] | None:
if quant_storage != torch.uint8 or not _supports_dtype(A.dtype):
return None

kernel_dtype = _kernel_dtype(A.dtype)
fn = _resolve_quant_fn(kernel_dtype, quant_type)
if fn is None:
return None

if kernel_dtype != A.dtype:
A_kernel = A.to(kernel_dtype)
else:
A_kernel = A

n = A.numel()
blocks = -(n // -blocksize)
absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage)

input_shim = MPSTensorShim.from_tensor(A_kernel)
absmax_shim = MPSTensorShim.from_tensor(absmax)
out_shim = MPSTensorShim.from_tensor(out)

fn(
input_shim.struct,
absmax_shim.struct,
out_shim.struct,
ct.c_int32(blocksize),
ct.c_int32(n),
)
return out, absmax


def _dequantize_4bit_native(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
quant_type: str,
dtype: torch.dtype,
out: torch.Tensor,
) -> bool:
if A.dtype != torch.uint8 or not _supports_dtype(dtype):
return False

_check_mps_device(absmax, "absmax")
kernel_dtype = _kernel_dtype(dtype)
fn = _resolve_dequant_fn(kernel_dtype, quant_type)
if fn is None:
return False

packed_shim = MPSTensorShim.from_tensor(A)
absmax_shim = MPSTensorShim.from_tensor(absmax)
if kernel_dtype != dtype:
work_out = torch.empty_like(out, dtype=kernel_dtype)
else:
work_out = out
out_shim = MPSTensorShim.from_tensor(work_out)

fn(
packed_shim.struct,
absmax_shim.struct,
out_shim.struct,
ct.c_int32(blocksize),
ct.c_int32(out.numel()),
)

if work_out is not out:
out.copy_(work_out.to(dtype))

return True


@register_kernel("bitsandbytes::quantize_4bit", "mps")
def _(
A: torch.Tensor,
blocksize: int,
quant_type: str,
quant_storage: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor]:
_check_mps_device(A, "A")
# result = _quantize_4bit_native(A, blocksize, quant_type, quant_storage)
# if result is not None:
# return result
return _quantize_4bit_impl(A, blocksize, quant_type, quant_storage)


@register_kernel("bitsandbytes::dequantize_4bit", "mps")
def _(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
quant_type: str,
shape: Sequence[int],
dtype: torch.dtype,
) -> torch.Tensor:
_check_mps_device(A, "A")
_check_mps_device(absmax, "absmax")
out = torch.empty(shape, dtype=dtype, device=A.device)
if _dequantize_4bit_native(A, absmax, blocksize, quant_type, dtype, out):
return out
else:
raise RuntimeError("Failed to dequantize 4bit on MPS")
return _dequantize_4bit_impl(A, absmax, blocksize, quant_type, shape, dtype)


@register_kernel("bitsandbytes::dequantize_4bit.out", "mps")
def _(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
quant_type: str,
shape: Sequence[int],
dtype: torch.dtype,
out: torch.Tensor,
) -> None:
_check_mps_device(A, "A")
_check_mps_device(out, "out")
_check_mps_device(absmax, "absmax")
torch._check(out.shape == tuple(shape), lambda: f"Expected out.shape == {tuple(shape)}, got {out.shape}")
torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")

if not _dequantize_4bit_native(A, absmax, blocksize, quant_type, dtype, out):
result = _dequantize_4bit_impl(A, absmax, blocksize, quant_type, shape, dtype)
out.copy_(result)
Loading
Loading