diff --git a/bitsandbytes/backends/default/ops.py b/bitsandbytes/backends/default/ops.py
index a0f0d2a34..707aeb3c3 100644
--- a/bitsandbytes/backends/default/ops.py
+++ b/bitsandbytes/backends/default/ops.py
@@ -1,4 +1,5 @@
 from collections.abc import Sequence
+from functools import wraps
 from math import prod, sqrt
 from typing import Optional
 
@@ -8,6 +9,32 @@
 from ..utils import CODE
 
 
+def _try_torch_compile(func=None, **compile_kwargs):
+    """
+    Wrapper around torch.compile that falls back to the original function if compilation fails.
+    """
+
+    def decorator(fn):
+        try:
+            compiled_fn = torch.compile(fn, **compile_kwargs)
+
+            @wraps(fn)
+            def wrapper(*args, **kwargs):
+                try:
+                    return compiled_fn(*args, **kwargs)
+                except Exception:
+                    return fn(*args, **kwargs)
+
+            return wrapper
+        except Exception:
+            return fn
+
+    if func is None:
+        return decorator
+    else:
+        return decorator(func)
+
+
 @register_kernel("bitsandbytes::int8_mm_dequant", "default")
 def _(
     A: torch.Tensor,
@@ -332,7 +359,7 @@
 }
 
 
-@torch.compile
+@_try_torch_compile
 def _optimizer_precondition_32bit(
     g: torch.Tensor,
     p: torch.Tensor,
@@ -393,7 +420,7 @@
     unorm_vec.add_(total_norm)
 
 
-@torch.compile
+@_try_torch_compile
 def _optimizer_update_32bit(
     g: torch.Tensor,
     p: torch.Tensor,
diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py
index 398fb83d3..2b92ee4f1 100644
--- a/tests/test_linear4bit.py
+++ b/tests/test_linear4bit.py
@@ -2,6 +2,7 @@
 import os
 import pickle
 import platform
+import sys
 from tempfile import TemporaryDirectory
 
 import pytest
@@ -320,6 +321,9 @@ def test_params4bit_real_serialization(device, quant_type, blocksize, compress_s
 @pytest.mark.parametrize("fullgraph", TRUE_FALSE, ids=id_formatter("fullgraph"))
 @pytest.mark.parametrize("mode", ["default", "reduce-overhead"], ids=id_formatter("mode"))
 @pytest.mark.skipif(torch.__version__ < (2, 4), reason="Not supported in torch < 2.4")
+@pytest.mark.skipif(
+    torch.__version__ < (2, 10) and sys.version_info >= (3, 14), reason="Not supported in Python 3.14 until torch 2.10"
+)
 def test_linear4bit_torch_compile(device, quant_type, compute_dtype, compress_statistics, bias, fullgraph, mode):
     if device == "hpu" and not is_supported_on_hpu(quant_type):
         pytest.skip("This configuration is not supported on HPU.")
diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py
index a0725d605..83f207d42 100644
--- a/tests/test_linear8bitlt.py
+++ b/tests/test_linear8bitlt.py
@@ -3,6 +3,7 @@
 import os
 import pickle
 import platform
+import sys
 from tempfile import TemporaryDirectory
 
 import pytest
@@ -234,6 +235,9 @@ def test_linear8bit_serialization(linear8bit):
 @pytest.mark.parametrize("fullgraph", TRUE_FALSE, ids=id_formatter("fullgraph"))
 @pytest.mark.parametrize("mode", ["default", "reduce-overhead"], ids=id_formatter("mode"))
 @pytest.mark.skipif(torch.__version__ < (2, 4), reason="Not supported in torch < 2.4")
+@pytest.mark.skipif(
+    torch.__version__ < (2, 10) and sys.version_info >= (3, 14), reason="Not supported in Python 3.14 until torch 2.10"
+)
 @pytest.mark.skipif(ROCM_WARP_SIZE_64, reason="this test is not supported on ROCm yet")
 def test_linear8bitlt_torch_compile(device, threshold, bias, fullgraph, mode):
     if device == "cuda" and platform.system() == "Windows":
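
As a usage sketch (not part of the patch itself), the new _try_torch_compile helper supports the same two call styles as torch.compile: as a bare decorator, matching the sites it replaces in ops.py, and with keyword arguments that are forwarded to torch.compile. The import path and the decorated function names below are illustrative assumptions only.

import torch

# Assumption: importing the private helper from the module patched above, purely for illustration.
from bitsandbytes.backends.default.ops import _try_torch_compile


@_try_torch_compile  # bare decorator, mirroring the @torch.compile it replaces
def _scaled_sum(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return 2.0 * x + y


@_try_torch_compile(dynamic=True)  # keyword arguments are forwarded to torch.compile
def _row_norm(x: torch.Tensor) -> torch.Tensor:
    return x.pow(2).sum(dim=-1).sqrt()


# Where torch.compile is unavailable or raises (e.g. Python 3.14 with torch < 2.10,
# per the test skip conditions above), both calls fall back to eager execution.
print(_scaled_sum(torch.ones(3), torch.zeros(3)))
print(_row_norm(torch.randn(4, 8)))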