diff --git a/src/tir/transforms/lower_intrin.cc b/src/tir/transforms/lower_intrin.cc index 6a7d2b27765d..ef844d9e0505 100644 --- a/src/tir/transforms/lower_intrin.cc +++ b/src/tir/transforms/lower_intrin.cc @@ -115,59 +115,15 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { if (analyzer_->CanProveGreaterEqual(op->a, 0) || analyzer_->CanProveGreaterEqual(e, 0)) { return truncdiv(op->a, op->b); } - - // If the numerator's lower bound is known, express the floordiv - // in terms of truncdiv using only positive operands. - - // The optimization below rewrites expressions involving `-a_min + (b - 1)`. - // Without proper bounds checking, this expression may overflow the dtype - // maximum, leading to non-equivalent transformations. - // To ensure safety, we require: - // b_max - a_min <= max_value_of_dtype + 1 - // This provides a conservative upper bound that prevents overflow and - // preserves the original semantics. - arith::ConstIntBound const_int_bound_a = analyzer_->const_int_bound(op->a); - arith::ConstIntBound const_int_bound_b = analyzer_->const_int_bound(op->b); - const int64_t max_value_of_dtype = - Downcast(tvm::max_value(op->a->dtype.element_of()))->value; - if (const_int_bound_a->min_value < 0 && - const_int_bound_b->max_value - const_int_bound_a->min_value <= max_value_of_dtype + 1) { - // The goal is to write floordiv(a,b) in terms of truncdiv, without using - // negative operands. - // - // For any integer c - // - // floordiv(a,b) == floordiv(a + b*c - b*c, b) - // == floordiv(a + b*c, b) - c - // - // Choosing `c = ceildiv(-a_min, b)`. This can be rewritten in terms of - // truncdiv as follows. - // - // c == ceildiv(-a_min,b) - // == floordiv(-a_min + (b-1), b) - // == truncdiv(-a_min + (b-1), b) - // - // When substituted into `a + b*c`, this results in a positive argument. 
- // - // a + b*c - // == a + b*ceildiv(-a_min,b) - // == a - b*floordiv(a_min,b) - // >= a - b*floordiv(a,b) - // == floormod(a, b) - // >= 0 - // - // Since the argument is positive, this allows floordiv to be written as - // followed. - // - // floordiv(a,b) - // == floordiv(a + b*c, b) - c - // == truncdiv(a + b*c, b) - c - IntImm min(op->a->dtype.element_of(), const_int_bound_a->min_value); - PrimExpr ceildiv = truncdiv((op->b - 1) - min, op->b); - PrimExpr offset_numerator = analyzer_->Simplify(op->a + op->b * ceildiv); - return truncdiv(offset_numerator, op->b) - ceildiv; + if (const IntImmNode* b_as_intimm = op->b.as()) { + int64_t b_value = b_as_intimm->value; + if (auto opt_c_value = TryFindShiftCoefficientForPositiveRange(op->a, b_value)) { + int64_t c_value = *opt_c_value; + // now we can safely lower to truncdiv + return truncdiv(op->a + make_const(dtype, b_value * c_value), op->b) - + make_const(dtype, c_value); + } } - DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divident"; PrimExpr rdiv = truncdiv(op->a, op->b); PrimExpr rmod = truncmod(op->a, op->b); @@ -221,58 +177,14 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { if (analyzer_->CanProveGreaterEqual(op->a, 0)) { return truncmod(op->a, op->b); } - - // If the numerator's lower bound is known, express the floormod - // in terms of truncmod using only positive operands. - - // The optimization below rewrites expressions involving `-a_min + (b - 1)`. - // Without proper bounds checking, this expression may overflow the dtype - // maximum, leading to non-equivalent transformations. - // To ensure safety, we require: - // b_max - a_min <= max_value_of_dtype + 1 - // This provides a conservative upper bound that prevents overflow and - // preserves the original semantics. 
- arith::ConstIntBound const_int_bound_a = analyzer_->const_int_bound(op->a); - arith::ConstIntBound const_int_bound_b = analyzer_->const_int_bound(op->b); - const int64_t max_value_of_dtype = - Downcast(tvm::max_value(op->a->dtype.element_of()))->value; - if (const_int_bound_a->min_value < 0 && - const_int_bound_b->max_value - const_int_bound_a->min_value <= max_value_of_dtype + 1) { - // The goal is to write floormod(a,b) in terms of truncdiv and truncmod, - // without using negative operands. - // - // For any integer c - // - // floormod(a, b) == floormod(a + b*c, b) - // - // Choosing `c = ceildiv(-a_min, b)`. This can be rewritten in terms of - // truncdiv as follows. - // - // c == ceildiv(-a_min,b) - // == floordiv(-a_min + (b-1), b) - // == truncdiv(-a_min + (b-1), b) - // - // When substituted into `a + b*c`, this results in a positive argument. - // - // a + b*c - // == a + b*ceildiv(-a_min,b) - // == a - b*floordiv(a_min,b) - // >= a - b*floordiv(a,b) - // == floormod(a, b) - // >= 0 - // - // Since the argument is positive, this allows floordiv to be written as - // followed. - // - // floormod(a,b) - // == floormod(a + b*c, b) - // == truncmod(a + b*c, b) - IntImm min(op->a->dtype.element_of(), const_int_bound_a->min_value); - PrimExpr ceildiv = truncdiv(-min + (op->b - 1), op->b); - PrimExpr offset_numerator = analyzer_->Simplify(op->a + op->b * ceildiv); - return truncmod(offset_numerator, op->b); + if (const IntImmNode* b_as_intimm = op->b.as()) { + int64_t b_value = b_as_intimm->value; + if (auto opt_c_value = TryFindShiftCoefficientForPositiveRange(op->a, b_value)) { + int64_t c_value = *opt_c_value; + // floormod(a, b) == floormod(a + b*c, b) == truncmod(a + b*c, b) + return truncmod(op->a + make_const(dtype, c_value * b_value), op->b); + } } - DLOG(INFO) << "LowerFloorMod: Cannot decide the sign of divident"; // NOTE:condition on b >= 0. 
// mod(a, b) < 0 will imply we are doing ceildiv, @@ -388,6 +300,49 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { return IRMutatorWithAnalyzer::VisitExpr_(op); } + /*! + * \brief Try to find a shift coefficient c such that a + b*c >= 0 and does not overflow. + * + * \param a the dividend + * \param b_value the constant divisor; non-positive values yield nullopt + * \return the shift coefficient c (always > 0), or nullopt if no safe c exists + */ + std::optional<int64_t> TryFindShiftCoefficientForPositiveRange(const PrimExpr& a, + int64_t b_value) { + if (b_value <= 0) { + return std::nullopt; + } + // NOTE: we need to be very careful in the checks below, to make sure + // all the intermediate calculations in both compiler checks and runtime checks + // do not overflow + arith::ConstIntBound const_int_bound_a = analyzer_->const_int_bound(a); + if (const_int_bound_a->min_value >= 0) { + return std::nullopt; + } + const int64_t max_value_of_dtype = + Downcast<IntImm>(tvm::max_value(a->dtype.element_of()))->value; + + // NOTE: ensures that (b-1) - a_min does not overflow + // also note: max_value_of_dtype + const_int_bound_a->min_value won't overflow + // since a_min is negative, adding it to a positive value will not overflow + if (b_value - 1 > max_value_of_dtype + const_int_bound_a->min_value) { + return std::nullopt; + } + int64_t c_value = ((b_value - 1) - const_int_bound_a->min_value) / b_value; + ICHECK_GT(c_value, 0); + // NOTE: c_value * b_value may overflow the dtype + if (c_value > max_value_of_dtype / b_value) return std::nullopt; + // need to check if the offset numerator will overflow + // to do so without overflowing, compare against max_value_of_dtype - b_value * c_value + // note that b_value * c_value is positive, max_value_of_dtype is also positive, so the + // subtraction will not overflow + if (const_int_bound_a->max_value > max_value_of_dtype - b_value * c_value) { + // a + b * c risks overflow + return std::nullopt; + } + return c_value; + } + // attribute maps, shared only when FLegalize == 
FLowerIntrinsic std::vector> attr_maps_; FLowerGeneral fma_{nullptr}; diff --git a/tests/python/tir-transform/test_tir_transform_lower_intrin.py b/tests/python/tir-transform/test_tir_transform_lower_intrin.py index 63f37e6f4179..a0a6ab2508f6 100644 --- a/tests/python/tir-transform/test_tir_transform_lower_intrin.py +++ b/tests/python/tir-transform/test_tir_transform_lower_intrin.py @@ -35,24 +35,35 @@ def lower_intrin(params, stmt): return stmt.value if lower_expr else stmt.body -def check_value(expr, vx, vy, data, fref): +def check_value(expr, variables, data, fref): + """ + Check that expr evaluates to fref(*row) for each row in data. + variables: list of TIR vars [x] or [x, y] bound to the columns of data. + data: list of tuples, each tuple has len(variables) elements. + """ n = len(data) - A = te.placeholder((n,), name="A", dtype=expr.dtype) - B = te.placeholder((n,), name="B", dtype=expr.dtype) + num_vars = len(variables) + assert num_vars >= 1 and all(len(row) == num_vars for row in data) + + placeholders = [ + te.placeholder((n,), name=f"v{i}", dtype=variables[i].dtype) for i in range(num_vars) + ] def make_binds(i): x = expr - x = tvm.tir.Let(vx, A[i], x) - x = tvm.tir.Let(vy, B[i], x) + for j in range(num_vars - 1, -1, -1): + x = tvm.tir.Let(variables[j], placeholders[j][i], x) return x C = te.compute((n,), make_binds) - f = tvm.compile(te.create_prim_func([A, B, C]), "llvm") - a = tvm.runtime.tensor(np.array([x for x, y in data], dtype=expr.dtype)) - b = tvm.runtime.tensor(np.array([y for x, y in data], dtype=expr.dtype)) - c = tvm.runtime.tensor(np.zeros(len(data), dtype=expr.dtype)) - f(a, b, c) - cref = np.array([fref(x, y) for x, y in data]) + f = tvm.compile(te.create_prim_func(placeholders + [C]), "llvm") + arrays = [ + tvm.runtime.tensor(np.array([row[j] for row in data], dtype=variables[j].dtype)) + for j in range(num_vars) + ] + c = tvm.runtime.tensor(np.zeros(n, dtype=expr.dtype)) + f(*arrays, c) + cref = np.array([fref(*row) for row in data]) 
np.testing.assert_equal(c.numpy(), cref) @@ -75,29 +86,29 @@ def test_lower_floordiv(): zero = tvm.tir.const(0, dtype) # no constraints res = lower_intrin([x, y], tvm.te.floordiv(x, y)) - check_value(res, x, y, data, lambda a, b: a // b) + check_value(res, [x, y], data, lambda a, b: a // b) # rhs >= 0 res = lower_intrin([x, y], tvm.tir.Select(y >= 0, tvm.te.floordiv(x, y), zero)) - check_value(res, x, y, data, lambda a, b: a // b if b > 0 else 0) + check_value(res, [x, y], data, lambda a, b: a // b if b > 0 else 0) # involves max res = lower_intrin( [x, y], tvm.tir.Select(y >= 0, tvm.te.max(tvm.te.floordiv(x, y), zero), zero) ) - check_value(res, x, y, data, lambda a, b: max(a // b, 0) if b > 0 else 0) + check_value(res, [x, y], data, lambda a, b: max(a // b, 0) if b > 0 else 0) # lhs >= 0 res = lower_intrin( [x, y], tvm.tir.Select(tvm.tir.all(y >= 0, x >= 0), tvm.te.floordiv(x, y), zero) ) - check_value(res, x, y, data, lambda a, b: a // b if b > 0 and a >= 0 else 0) + check_value(res, [x, y], data, lambda a, b: a // b if b > 0 and a >= 0 else 0) # const power of two res = lower_intrin([x, y], tvm.te.floordiv(x, tvm.tir.const(8, dtype=dtype))) - check_value(res, x, y, [(a, b) for a, b in data if b == 8], lambda a, b: a // b) + check_value(res, [x, y], [(a, b) for a, b in data if b == 8], lambda a, b: a // b) # floordiv(x + m, k), m and k are positive constant. 2 <= m <= k-1. 
res = lower_intrin( [x, y], tvm.te.floordiv(x + tvm.tir.const(4, dtype=dtype), tvm.tir.const(5, dtype=dtype)), ) - check_value(res, x, y, [(a, b) for a, b in data if b == 5], lambda a, b: (a + 4) // b) + check_value(res, [x, y], [(a, b) for a, b in data if b == 5], lambda a, b: (a + 4) // b) @tvm.testing.requires_llvm @@ -109,26 +120,60 @@ def test_lower_floormod(): zero = tvm.tir.const(0, dtype) # no constraints res = lower_intrin([x, y], tvm.te.floormod(x, y)) - check_value(res, x, y, data, lambda a, b: a % b) + check_value(res, [x, y], data, lambda a, b: a % b) # rhs >= 0 res = lower_intrin([x, y], tvm.tir.Select(y >= 0, tvm.te.floormod(x, y), zero)) - check_value(res, x, y, data, lambda a, b: a % b if b > 0 else 0) + check_value(res, [x, y], data, lambda a, b: a % b if b > 0 else 0) # lhs >= 0 res = lower_intrin( [x, y], tvm.tir.Select(tvm.tir.all(y >= 0, x >= 0), tvm.te.floormod(x, y), zero) ) - check_value(res, x, y, data, lambda a, b: a % b if b > 0 and a >= 0 else 0) + check_value(res, [x, y], data, lambda a, b: a % b if b > 0 and a >= 0 else 0) # const power of two res = lower_intrin([x, y], tvm.te.floormod(x, tvm.tir.const(8, dtype=dtype))) - check_value(res, x, y, [(a, b) for a, b in data if b == 8], lambda a, b: a % b) + check_value(res, [x, y], [(a, b) for a, b in data if b == 8], lambda a, b: a % b) # floormod(x + m, k), m and k are positive constant. 2 <= m <= k-1. res = lower_intrin( [x, y], tvm.te.floormod(x + tvm.tir.const(4, dtype=dtype), tvm.tir.const(5, dtype=dtype)), ) - check_value(res, x, y, [(a, b) for a, b in data if b == 5], lambda a, b: (a + 4) % b) + check_value(res, [x, y], [(a, b) for a, b in data if b == 5], lambda a, b: (a + 4) % b) + + +@tvm.testing.requires_llvm +def test_lower_floordiv_overflow_checks(): + """ + Regression tests for overflow checks in TryFindShiftCoefficientForPositiveRange. + Divisor is constant 3 (not 1 to avoid CSE, not power-of-two so we don't take the shift path). 
+ Reuses lower_intrin and check_value; overflow tests use one var [x]. + """ + # Check 3: (b-1) - a_min must not overflow (numerator and C++ int64). + # x (int64) full range -> min_value = -2^63. With b = 3: numerator = 2 - (-2^63) > LLONG_MAX. + x = te.var("x", dtype="int64") + res = lower_intrin([x], tvm.te.floordiv(x, tvm.tir.const(3, "int64"))) + data_check3 = [(-(2**63),), (0,), (100,)] + check_value(res, [x], data_check3, lambda a: a // 3) + + # Check 4: c_value * b_value must not overflow dtype. + # x (int16) full range -> min_value = -32768, c = ceil(32770/3) = 10923; 10923*3 > 32767. + x = te.var("x", dtype="int16") + res = lower_intrin([x], tvm.te.floordiv(x, tvm.tir.const(3, "int16"))) + data_check4 = [(-32768,), (0,), (100,)] + check_value(res, [x], data_check4, lambda a: a // 3) + + # Check 5: a_max + b*c must not overflow (offset numerator). + # tir.min(tir.max(x, -10), 32758) can give bounds [-10, 32758]; b=3, c=4; a_max + 12 > 32767. + # In practice this path may not be triggered. This test still validates correct lowering. + x = te.var("x", dtype="int16") + clamped = tvm.tir.min( + tvm.tir.max(x, tvm.tir.const(-10, "int16")), tvm.tir.const(32758, "int16") + ) + res = lower_intrin([x], tvm.te.floordiv(clamped, tvm.tir.const(3, "int16"))) + data_check5 = [(-10,), (0,), (32758,), (32757,)] + check_value(res, [x], data_check5, lambda a: (min(max(a, -10), 32758)) // 3) if __name__ == "__main__": test_lower_floordiv() test_lower_floormod() + test_lower_floordiv_overflow_checks()