diff --git a/src/tir/transforms/lower_intrin.cc b/src/tir/transforms/lower_intrin.cc index 4c35fdb2902f..6a7d2b27765d 100644 --- a/src/tir/transforms/lower_intrin.cc +++ b/src/tir/transforms/lower_intrin.cc @@ -118,10 +118,20 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // If the numerator's lower bound is known, express the floordiv // in terms of truncdiv using only positive operands. - arith::ConstIntBound const_int_bound = analyzer_->const_int_bound(op->a); - if (const_int_bound->min_value < 0 && - const_int_bound->min_value > - -(Downcast(tvm::max_value(op->a->dtype.element_of()))->value)) { + + // The optimization below rewrites expressions involving `-a_min + (b - 1)`. + // Without proper bounds checking, this expression may overflow the dtype + // maximum, leading to non-equivalent transformations. + // To ensure safety, we require: + // b_max - a_min <= max_value_of_dtype + 1 + // This provides a conservative upper bound that prevents overflow and + // preserves the original semantics. + arith::ConstIntBound const_int_bound_a = analyzer_->const_int_bound(op->a); + arith::ConstIntBound const_int_bound_b = analyzer_->const_int_bound(op->b); + const int64_t max_value_of_dtype = + Downcast(tvm::max_value(op->a->dtype.element_of()))->value; + if (const_int_bound_a->min_value < 0 && + const_int_bound_b->max_value - const_int_bound_a->min_value <= max_value_of_dtype + 1) { // The goal is to write floordiv(a,b) in terms of truncdiv, without using // negative operands. // @@ -152,7 +162,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // floordiv(a,b) // == floordiv(a + b*c, b) - c // == truncdiv(a + b*c, b) - c - IntImm min(op->a->dtype.element_of(), const_int_bound->min_value); + IntImm min(op->a->dtype.element_of(), const_int_bound_a->min_value); PrimExpr ceildiv = truncdiv((op->b - 1) - min, op->b); PrimExpr offset_numerator = analyzer_->Simplify(op->a + op->b * ceildiv); return truncdiv(offset_numerator, op->b) - ceildiv; @@ -214,10 +224,20 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // If the numerator's lower bound is known, express the floormod // in terms of truncmod using only positive operands. - arith::ConstIntBound const_int_bound = analyzer_->const_int_bound(op->a); - if (const_int_bound->min_value < 0 && - const_int_bound->min_value > - -(Downcast(tvm::max_value(op->a->dtype.element_of()))->value)) { + + // The optimization below rewrites expressions involving `-a_min + (b - 1)`. + // Without proper bounds checking, this expression may overflow the dtype + // maximum, leading to non-equivalent transformations. + // To ensure safety, we require: + // b_max - a_min <= max_value_of_dtype + 1 + // This provides a conservative upper bound that prevents overflow and + // preserves the original semantics. + arith::ConstIntBound const_int_bound_a = analyzer_->const_int_bound(op->a); + arith::ConstIntBound const_int_bound_b = analyzer_->const_int_bound(op->b); + const int64_t max_value_of_dtype = + Downcast(tvm::max_value(op->a->dtype.element_of()))->value; + if (const_int_bound_a->min_value < 0 && + const_int_bound_b->max_value - const_int_bound_a->min_value <= max_value_of_dtype + 1) { // The goal is to write floormod(a,b) in terms of truncdiv and truncmod, // without using negative operands. // @@ -247,7 +267,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // floormod(a,b) // == floormod(a + b*c, b) // == truncmod(a + b*c, b) - IntImm min(op->a->dtype.element_of(), const_int_bound->min_value); + IntImm min(op->a->dtype.element_of(), const_int_bound_a->min_value); PrimExpr ceildiv = truncdiv(-min + (op->b - 1), op->b); PrimExpr offset_numerator = analyzer_->Simplify(op->a + op->b * ceildiv); return truncmod(offset_numerator, op->b); diff --git a/tests/python/tir-transform/test_tir_transform_lower_intrin.py b/tests/python/tir-transform/test_tir_transform_lower_intrin.py index 864b24bc0f51..63f37e6f4179 100644 --- a/tests/python/tir-transform/test_tir_transform_lower_intrin.py +++ b/tests/python/tir-transform/test_tir_transform_lower_intrin.py @@ -92,6 +92,12 @@ def test_lower_floordiv(): # const power of two res = lower_intrin([x, y], tvm.te.floordiv(x, tvm.tir.const(8, dtype=dtype))) check_value(res, x, y, [(a, b) for a, b in data if b == 8], lambda a, b: a // b) + # floordiv(x + m, k), m and k are positive constant. 2 <= m <= k-1. + res = lower_intrin( + [x, y], + tvm.te.floordiv(x + tvm.tir.const(4, dtype=dtype), tvm.tir.const(5, dtype=dtype)), + ) + check_value(res, x, y, [(a, b) for a, b in data if b == 5], lambda a, b: (a + 4) // b) @tvm.testing.requires_llvm @@ -115,6 +121,12 @@ def test_lower_floormod(): # const power of two res = lower_intrin([x, y], tvm.te.floormod(x, tvm.tir.const(8, dtype=dtype))) check_value(res, x, y, [(a, b) for a, b in data if b == 8], lambda a, b: a % b) + # floormod(x + m, k), m and k are positive constant. 2 <= m <= k-1. + res = lower_intrin( + [x, y], + tvm.te.floormod(x + tvm.tir.const(4, dtype=dtype), tvm.tir.const(5, dtype=dtype)), + ) + check_value(res, x, y, [(a, b) for a, b in data if b == 5], lambda a, b: (a + 4) % b) if __name__ == "__main__":