From 364b00e9c49b8b6d42d305ff3aa65dd8728db787 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Thu, 11 Dec 2025 10:08:50 -0500 Subject: [PATCH 1/5] Test on Python 3.14 (temporary) --- .github/workflows/test-runner.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-runner.yml b/.github/workflows/test-runner.yml index 12c32b828..ab805e4a9 100644 --- a/.github/workflows/test-runner.yml +++ b/.github/workflows/test-runner.yml @@ -196,7 +196,7 @@ jobs: - name: Setup Python uses: actions/setup-python@v5 with: - python-version: '3.10' + python-version: '3.14' # Windows: Setup MSVC for torch.compile - name: Setup MSVC From ff5a6b4f34a78cecc2ee8e966050833a5a9c6d99 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Thu, 11 Dec 2025 10:25:59 -0500 Subject: [PATCH 2/5] Fix: Python 3.14 / torch.compile compatibility --- bitsandbytes/backends/default/ops.py | 31 ++++++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/bitsandbytes/backends/default/ops.py b/bitsandbytes/backends/default/ops.py index a0f0d2a34..707aeb3c3 100644 --- a/bitsandbytes/backends/default/ops.py +++ b/bitsandbytes/backends/default/ops.py @@ -1,4 +1,5 @@ from collections.abc import Sequence +from functools import wraps from math import prod, sqrt from typing import Optional @@ -8,6 +9,32 @@ from ..utils import CODE +def _try_torch_compile(func=None, **compile_kwargs): + """ + Wrapper around torch.compile that falls back to the original function if compilation fails. + """ + + def decorator(fn): + try: + compiled_fn = torch.compile(fn, **compile_kwargs) + + @wraps(fn) + def wrapper(*args, **kwargs): + try: + return compiled_fn(*args, **kwargs) + except Exception: + return fn(*args, **kwargs) + + return wrapper + except Exception: + return fn + + if func is None: + return decorator + else: + return decorator(func) + + @register_kernel("bitsandbytes::int8_mm_dequant", "default") def _( A: torch.Tensor, @@ -332,7 +359,7 @@ def _( } -@torch.compile +@_try_torch_compile def _optimizer_precondition_32bit( g: torch.Tensor, p: torch.Tensor, @@ -393,7 +420,7 @@ def _optimizer_precondition_32bit( unorm_vec.add_(total_norm) -@torch.compile +@_try_torch_compile def _optimizer_update_32bit( g: torch.Tensor, p: torch.Tensor, From 8418e52050341a53f429cd7a4b1fe94abfd569ac Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Thu, 11 Dec 2025 10:58:43 -0500 Subject: [PATCH 3/5] Skip torch.compile test on Python 3.14 and torch < 2.10 (not supported) --- tests/test_linear4bit.py | 2 ++ tests/test_linear8bitlt.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py index 398fb83d3..3dd231e09 100644 --- a/tests/test_linear4bit.py +++ b/tests/test_linear4bit.py @@ -2,6 +2,7 @@ import os import pickle import platform +import sys from tempfile import TemporaryDirectory import pytest @@ -320,6 +321,7 @@ def test_params4bit_real_serialization(device, quant_type, blocksize, compress_s @pytest.mark.parametrize("fullgraph", TRUE_FALSE, ids=id_formatter("fullgraph")) @pytest.mark.parametrize("mode", ["default", "reduce-overhead"], ids=id_formatter("mode")) @pytest.mark.skipif(torch.__version__ < (2, 4), reason="Not supported in torch < 2.4") +@pytest.mark.skipif(torch.__version__ < (2, 10) and sys.version_info >= (3, 14), reason="Not supported in Python 3.14 until torch 2.10") def test_linear4bit_torch_compile(device, quant_type, compute_dtype, compress_statistics, bias, fullgraph, mode): if device == "hpu" and not is_supported_on_hpu(quant_type): pytest.skip("This configuration is not supported on HPU.") diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py index a0725d605..2f391f0d3 100644 --- a/tests/test_linear8bitlt.py +++ b/tests/test_linear8bitlt.py @@ -3,6 +3,7 @@ import os import pickle import platform +import sys from tempfile import TemporaryDirectory import pytest @@ -234,6 +235,7 @@ def test_linear8bit_serialization(linear8bit): @pytest.mark.parametrize("fullgraph", TRUE_FALSE, ids=id_formatter("fullgraph")) @pytest.mark.parametrize("mode", ["default", "reduce-overhead"], ids=id_formatter("mode")) @pytest.mark.skipif(torch.__version__ < (2, 4), reason="Not supported in torch < 2.4") +@pytest.mark.skipif(torch.__version__ < (2, 10) and sys.version_info >= (3, 14), reason="Not supported in Python 3.14 until torch 2.10") @pytest.mark.skipif(ROCM_WARP_SIZE_64, reason="this test is not supported on ROCm yet") def test_linear8bitlt_torch_compile(device, threshold, bias, fullgraph, mode): if device == "cuda" and platform.system() == "Windows": From 2ea824f5ac40bdc89e104bc33ec1d290c6d764e4 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Thu, 11 Dec 2025 12:10:33 -0500 Subject: [PATCH 4/5] restore --- .github/workflows/test-runner.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-runner.yml b/.github/workflows/test-runner.yml index ab805e4a9..12c32b828 100644 --- a/.github/workflows/test-runner.yml +++ b/.github/workflows/test-runner.yml @@ -196,7 +196,7 @@ jobs: - name: Setup Python uses: actions/setup-python@v5 with: - python-version: '3.14' + python-version: '3.10' # Windows: Setup MSVC for torch.compile - name: Setup MSVC From 3ba976a8f7472437155d684462f913384b3fe0b9 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Thu, 11 Dec 2025 12:35:12 -0500 Subject: [PATCH 5/5] Format --- tests/test_linear4bit.py | 4 +++- tests/test_linear8bitlt.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py index 3dd231e09..2b92ee4f1 100644 --- a/tests/test_linear4bit.py +++ b/tests/test_linear4bit.py @@ -321,7 +321,9 @@ def test_params4bit_real_serialization(device, quant_type, blocksize, compress_s @pytest.mark.parametrize("fullgraph", TRUE_FALSE, ids=id_formatter("fullgraph")) @pytest.mark.parametrize("mode", ["default", "reduce-overhead"], ids=id_formatter("mode")) @pytest.mark.skipif(torch.__version__ < (2, 4), reason="Not supported in torch < 2.4") -@pytest.mark.skipif(torch.__version__ < (2, 10) and sys.version_info >= (3, 14), reason="Not supported in Python 3.14 until torch 2.10") +@pytest.mark.skipif( + torch.__version__ < (2, 10) and sys.version_info >= (3, 14), reason="Not supported in Python 3.14 until torch 2.10" +) def test_linear4bit_torch_compile(device, quant_type, compute_dtype, compress_statistics, bias, fullgraph, mode): if device == "hpu" and not is_supported_on_hpu(quant_type): pytest.skip("This configuration is not supported on HPU.") diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py index 2f391f0d3..83f207d42 100644 --- a/tests/test_linear8bitlt.py +++ b/tests/test_linear8bitlt.py @@ -235,7 +235,9 @@ def test_linear8bit_serialization(linear8bit): @pytest.mark.parametrize("fullgraph", TRUE_FALSE, ids=id_formatter("fullgraph")) @pytest.mark.parametrize("mode", ["default", "reduce-overhead"], ids=id_formatter("mode")) @pytest.mark.skipif(torch.__version__ < (2, 4), reason="Not supported in torch < 2.4") -@pytest.mark.skipif(torch.__version__ < (2, 10) and sys.version_info >= (3, 14), reason="Not supported in Python 3.14 until torch 2.10") +@pytest.mark.skipif( + torch.__version__ < (2, 10) and sys.version_info >= (3, 14), reason="Not supported in Python 3.14 until torch 2.10" +) @pytest.mark.skipif(ROCM_WARP_SIZE_64, reason="this test is not supported on ROCm yet") def test_linear8bitlt_torch_compile(device, threshold, bias, fullgraph, mode): if device == "cuda" and platform.system() == "Windows":