From d88e041ff4948ec7162b98f6497bb021a86bcc2f Mon Sep 17 00:00:00 2001 From: guocj <2281193216@qq.com> Date: Thu, 29 Jan 2026 15:22:11 +0800 Subject: [PATCH] [BugFix][TIR] Fix incorrect optimization when lowering floordiv and floormod This patch fixes an issue in the LowerIntrin pass where incorrect optimizations were applied to floordiv and floormod operations. The root cause is that the pass attempts to find an equivalent representation for floordiv(a, b) by calculating the expression (op->b - 1) - a_min. This expression, when subjected to constant folding, can potentially overflow the range of int32 or int16. When this overflow occurs, the transformation becomes invalid and no longer equivalent to the original operation. To fix this, we enhanced the condition under which the transformation is applied. The new condition ensures that the transformation is only performed when (b_max - a_min) is less than INT_MAX + 2. If this condition is not met, the transformation is skipped and the common lowering steps are followed to ensure correctness. A regression test has been added to cover this case. Fixes #18684 add comment --- src/tir/transforms/lower_intrin.cc | 40 ++++++++++++++----- .../test_tir_transform_lower_intrin.py | 12 ++++++ 2 files changed, 42 insertions(+), 10 deletions(-) diff --git a/src/tir/transforms/lower_intrin.cc b/src/tir/transforms/lower_intrin.cc index 4c35fdb2902f..6a7d2b27765d 100644 --- a/src/tir/transforms/lower_intrin.cc +++ b/src/tir/transforms/lower_intrin.cc @@ -118,10 +118,20 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // If the numerator's lower bound is known, express the floordiv // in terms of truncdiv using only positive operands. - arith::ConstIntBound const_int_bound = analyzer_->const_int_bound(op->a); - if (const_int_bound->min_value < 0 && - const_int_bound->min_value > - -(Downcast(tvm::max_value(op->a->dtype.element_of()))->value)) { + + // The optimization below rewrites expressions involving `-a_min + (b - 1)`. + // Without proper bounds checking, this expression may overflow the dtype + // maximum, leading to non-equivalent transformations. + // To ensure safety, we require: + // b_max - a_min <= max_value_of_dtype + 1 + // This provides a conservative upper bound that prevents overflow and + // preserves the original semantics. + arith::ConstIntBound const_int_bound_a = analyzer_->const_int_bound(op->a); + arith::ConstIntBound const_int_bound_b = analyzer_->const_int_bound(op->b); + const int64_t max_value_of_dtype = + Downcast(tvm::max_value(op->a->dtype.element_of()))->value; + if (const_int_bound_a->min_value < 0 && + const_int_bound_b->max_value - const_int_bound_a->min_value <= max_value_of_dtype + 1) { // The goal is to write floordiv(a,b) in terms of truncdiv, without using // negative operands. // @@ -152,7 +162,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // floordiv(a,b) // == floordiv(a + b*c, b) - c // == truncdiv(a + b*c, b) - c - IntImm min(op->a->dtype.element_of(), const_int_bound->min_value); + IntImm min(op->a->dtype.element_of(), const_int_bound_a->min_value); PrimExpr ceildiv = truncdiv((op->b - 1) - min, op->b); PrimExpr offset_numerator = analyzer_->Simplify(op->a + op->b * ceildiv); return truncdiv(offset_numerator, op->b) - ceildiv; @@ -214,10 +224,20 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // If the numerator's lower bound is known, express the floormod // in terms of truncmod using only positive operands. - arith::ConstIntBound const_int_bound = analyzer_->const_int_bound(op->a); - if (const_int_bound->min_value < 0 && - const_int_bound->min_value > - -(Downcast(tvm::max_value(op->a->dtype.element_of()))->value)) { + + // The optimization below rewrites expressions involving `-a_min + (b - 1)`. + // Without proper bounds checking, this expression may overflow the dtype + // maximum, leading to non-equivalent transformations. + // To ensure safety, we require: + // b_max - a_min <= max_value_of_dtype + 1 + // This provides a conservative upper bound that prevents overflow and + // preserves the original semantics. + arith::ConstIntBound const_int_bound_a = analyzer_->const_int_bound(op->a); + arith::ConstIntBound const_int_bound_b = analyzer_->const_int_bound(op->b); + const int64_t max_value_of_dtype = + Downcast(tvm::max_value(op->a->dtype.element_of()))->value; + if (const_int_bound_a->min_value < 0 && + const_int_bound_b->max_value - const_int_bound_a->min_value <= max_value_of_dtype + 1) { // The goal is to write floormod(a,b) in terms of truncdiv and truncmod, // without using negative operands. // @@ -247,7 +267,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // floormod(a,b) // == floormod(a + b*c, b) // == truncmod(a + b*c, b) - IntImm min(op->a->dtype.element_of(), const_int_bound->min_value); + IntImm min(op->a->dtype.element_of(), const_int_bound_a->min_value); PrimExpr ceildiv = truncdiv(-min + (op->b - 1), op->b); PrimExpr offset_numerator = analyzer_->Simplify(op->a + op->b * ceildiv); return truncmod(offset_numerator, op->b); diff --git a/tests/python/tir-transform/test_tir_transform_lower_intrin.py b/tests/python/tir-transform/test_tir_transform_lower_intrin.py index 864b24bc0f51..63f37e6f4179 100644 --- a/tests/python/tir-transform/test_tir_transform_lower_intrin.py +++ b/tests/python/tir-transform/test_tir_transform_lower_intrin.py @@ -92,6 +92,12 @@ def test_lower_floordiv(): # const power of two res = lower_intrin([x, y], tvm.te.floordiv(x, tvm.tir.const(8, dtype=dtype))) check_value(res, x, y, [(a, b) for a, b in data if b == 8], lambda a, b: a // b) + # floordiv(x + m, k), m and k are positive constant. 2 <= m <= k-1. + res = lower_intrin( + [x, y], + tvm.te.floordiv(x + tvm.tir.const(4, dtype=dtype), tvm.tir.const(5, dtype=dtype)), + ) + check_value(res, x, y, [(a, b) for a, b in data if b == 5], lambda a, b: (a + 4) // b) @tvm.testing.requires_llvm @@ -115,6 +121,12 @@ def test_lower_floormod(): # const power of two res = lower_intrin([x, y], tvm.te.floormod(x, tvm.tir.const(8, dtype=dtype))) check_value(res, x, y, [(a, b) for a, b in data if b == 8], lambda a, b: a % b) + # floormod(x + m, k), m and k are positive constant. 2 <= m <= k-1. + res = lower_intrin( + [x, y], + tvm.te.floormod(x + tvm.tir.const(4, dtype=dtype), tvm.tir.const(5, dtype=dtype)), + ) + check_value(res, x, y, [(a, b) for a, b in data if b == 5], lambda a, b: (a + 4) % b) if __name__ == "__main__":