Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 29 additions & 2 deletions bitsandbytes/backends/default/ops.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections.abc import Sequence
from functools import wraps
from math import prod, sqrt
from typing import Optional

Expand All @@ -8,6 +9,32 @@
from ..utils import CODE


def _try_torch_compile(func=None, **compile_kwargs):
"""
Wrapper around torch.compile that falls back to the original function if compilation fails.
"""

def decorator(fn):
try:
compiled_fn = torch.compile(fn, **compile_kwargs)

@wraps(fn)
def wrapper(*args, **kwargs):
try:
return compiled_fn(*args, **kwargs)
except Exception:
return fn(*args, **kwargs)

return wrapper
except Exception:
return fn

if func is None:
return decorator
else:
return decorator(func)


@register_kernel("bitsandbytes::int8_mm_dequant", "default")
def _(
A: torch.Tensor,
Expand Down Expand Up @@ -332,7 +359,7 @@ def _(
}


@torch.compile
@_try_torch_compile
def _optimizer_precondition_32bit(
g: torch.Tensor,
p: torch.Tensor,
Expand Down Expand Up @@ -393,7 +420,7 @@ def _optimizer_precondition_32bit(
unorm_vec.add_(total_norm)


@torch.compile
@_try_torch_compile
def _optimizer_update_32bit(
g: torch.Tensor,
p: torch.Tensor,
Expand Down
4 changes: 4 additions & 0 deletions tests/test_linear4bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import pickle
import platform
import sys
from tempfile import TemporaryDirectory

import pytest
Expand Down Expand Up @@ -320,6 +321,9 @@ def test_params4bit_real_serialization(device, quant_type, blocksize, compress_s
@pytest.mark.parametrize("fullgraph", TRUE_FALSE, ids=id_formatter("fullgraph"))
@pytest.mark.parametrize("mode", ["default", "reduce-overhead"], ids=id_formatter("mode"))
@pytest.mark.skipif(torch.__version__ < (2, 4), reason="Not supported in torch < 2.4")
@pytest.mark.skipif(
torch.__version__ < (2, 10) and sys.version_info >= (3, 14), reason="Not supported in Python 3.14 until torch 2.10"
)
def test_linear4bit_torch_compile(device, quant_type, compute_dtype, compress_statistics, bias, fullgraph, mode):
if device == "hpu" and not is_supported_on_hpu(quant_type):
pytest.skip("This configuration is not supported on HPU.")
Expand Down
4 changes: 4 additions & 0 deletions tests/test_linear8bitlt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import pickle
import platform
import sys
from tempfile import TemporaryDirectory

import pytest
Expand Down Expand Up @@ -234,6 +235,9 @@ def test_linear8bit_serialization(linear8bit):
@pytest.mark.parametrize("fullgraph", TRUE_FALSE, ids=id_formatter("fullgraph"))
@pytest.mark.parametrize("mode", ["default", "reduce-overhead"], ids=id_formatter("mode"))
@pytest.mark.skipif(torch.__version__ < (2, 4), reason="Not supported in torch < 2.4")
@pytest.mark.skipif(
torch.__version__ < (2, 10) and sys.version_info >= (3, 14), reason="Not supported in Python 3.14 until torch 2.10"
)
@pytest.mark.skipif(ROCM_WARP_SIZE_64, reason="this test is not supported on ROCm yet")
def test_linear8bitlt_torch_compile(device, threshold, bias, fullgraph, mode):
if device == "cuda" and platform.system() == "Windows":
Expand Down