Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 58 additions & 103 deletions src/tir/transforms/lower_intrin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,59 +115,15 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
if (analyzer_->CanProveGreaterEqual(op->a, 0) || analyzer_->CanProveGreaterEqual(e, 0)) {
return truncdiv(op->a, op->b);
}

// If the numerator's lower bound is known, express the floordiv
// in terms of truncdiv using only positive operands.

// The optimization below rewrites expressions involving `-a_min + (b - 1)`.
// Without proper bounds checking, this expression may overflow the dtype
// maximum, leading to non-equivalent transformations.
// To ensure safety, we require:
// b_max - a_min <= max_value_of_dtype + 1
// This provides a conservative upper bound that prevents overflow and
// preserves the original semantics.
arith::ConstIntBound const_int_bound_a = analyzer_->const_int_bound(op->a);
arith::ConstIntBound const_int_bound_b = analyzer_->const_int_bound(op->b);
const int64_t max_value_of_dtype =
Downcast<IntImm>(tvm::max_value(op->a->dtype.element_of()))->value;
if (const_int_bound_a->min_value < 0 &&
const_int_bound_b->max_value - const_int_bound_a->min_value <= max_value_of_dtype + 1) {
// The goal is to write floordiv(a,b) in terms of truncdiv, without using
// negative operands.
//
// For any integer c
//
// floordiv(a,b) == floordiv(a + b*c - b*c, b)
// == floordiv(a + b*c, b) - c
//
// Choosing `c = ceildiv(-a_min, b)`. This can be rewritten in terms of
// truncdiv as follows.
//
// c == ceildiv(-a_min,b)
// == floordiv(-a_min + (b-1), b)
// == truncdiv(-a_min + (b-1), b)
//
// When substituted into `a + b*c`, this results in a positive argument.
//
// a + b*c
// == a + b*ceildiv(-a_min,b)
// == a - b*floordiv(a_min,b)
// >= a - b*floordiv(a,b)
// == floormod(a, b)
// >= 0
//
// Since the argument is positive, this allows floordiv to be written as
// followed.
//
// floordiv(a,b)
// == floordiv(a + b*c, b) - c
// == truncdiv(a + b*c, b) - c
IntImm min(op->a->dtype.element_of(), const_int_bound_a->min_value);
PrimExpr ceildiv = truncdiv((op->b - 1) - min, op->b);
PrimExpr offset_numerator = analyzer_->Simplify(op->a + op->b * ceildiv);
return truncdiv(offset_numerator, op->b) - ceildiv;
if (const IntImmNode* b_as_intimm = op->b.as<IntImmNode>()) {
int64_t b_value = b_as_intimm->value;
if (auto opt_c_value = TryFindShiftCoefficientForPositiveRange(op->a, b_value)) {
int64_t c_value = *opt_c_value;
// now we can safely lower to truncdiv
return truncdiv(op->a + make_const(dtype, b_value * c_value), op->b) -
make_const(dtype, c_value);
}
}

DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divident";
PrimExpr rdiv = truncdiv(op->a, op->b);
PrimExpr rmod = truncmod(op->a, op->b);
Expand Down Expand Up @@ -221,58 +177,14 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
if (analyzer_->CanProveGreaterEqual(op->a, 0)) {
return truncmod(op->a, op->b);
}

// If the numerator's lower bound is known, express the floormod
// in terms of truncmod using only positive operands.

// The optimization below rewrites expressions involving `-a_min + (b - 1)`.
// Without proper bounds checking, this expression may overflow the dtype
// maximum, leading to non-equivalent transformations.
// To ensure safety, we require:
// b_max - a_min <= max_value_of_dtype + 1
// This provides a conservative upper bound that prevents overflow and
// preserves the original semantics.
arith::ConstIntBound const_int_bound_a = analyzer_->const_int_bound(op->a);
arith::ConstIntBound const_int_bound_b = analyzer_->const_int_bound(op->b);
const int64_t max_value_of_dtype =
Downcast<IntImm>(tvm::max_value(op->a->dtype.element_of()))->value;
if (const_int_bound_a->min_value < 0 &&
const_int_bound_b->max_value - const_int_bound_a->min_value <= max_value_of_dtype + 1) {
// The goal is to write floormod(a,b) in terms of truncdiv and truncmod,
// without using negative operands.
//
// For any integer c
//
// floormod(a, b) == floormod(a + b*c, b)
//
// Choosing `c = ceildiv(-a_min, b)`. This can be rewritten in terms of
// truncdiv as follows.
//
// c == ceildiv(-a_min,b)
// == floordiv(-a_min + (b-1), b)
// == truncdiv(-a_min + (b-1), b)
//
// When substituted into `a + b*c`, this results in a positive argument.
//
// a + b*c
// == a + b*ceildiv(-a_min,b)
// == a - b*floordiv(a_min,b)
// >= a - b*floordiv(a,b)
// == floormod(a, b)
// >= 0
//
// Since the argument is positive, this allows floordiv to be written as
// followed.
//
// floormod(a,b)
// == floormod(a + b*c, b)
// == truncmod(a + b*c, b)
IntImm min(op->a->dtype.element_of(), const_int_bound_a->min_value);
PrimExpr ceildiv = truncdiv(-min + (op->b - 1), op->b);
PrimExpr offset_numerator = analyzer_->Simplify(op->a + op->b * ceildiv);
return truncmod(offset_numerator, op->b);
if (const IntImmNode* b_as_intimm = op->b.as<IntImmNode>()) {
int64_t b_value = b_as_intimm->value;
if (auto opt_c_value = TryFindShiftCoefficientForPositiveRange(op->a, b_value)) {
int64_t c_value = *opt_c_value;
// floormod(a, b) == floormod(a + b*c, b) == truncmod(a + b*c, b)
return truncmod(op->a + make_const(dtype, c_value * b_value), op->b);
}
}

DLOG(INFO) << "LowerFloorMod: Cannot decide the sign of divident";
// NOTE:condition on b >= 0.
// mod(a, b) < 0 will imply we are doing ceildiv,
Expand Down Expand Up @@ -388,6 +300,49 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
return IRMutatorWithAnalyzer::VisitExpr_(op);
}

/*!
* \brief Try to find a shift co-efficient c such that a + b*c positive and does not overflow.
*
* \param a the dividend
* \param b_value the divisor
* \return the shift co-efficient c, or nullopt if not found
*/
std::optional<int64_t> TryFindShiftCoefficientForPositiveRange(const PrimExpr& a,
int64_t b_value) {
if (b_value <= 0) {
return std::nullopt;
}
// NOTE: we need to be very careful in the checks below, to make sure
// all the intermediate calculations in both compiler checks and runtime checks
// do not overflow
arith::ConstIntBound const_int_bound_a = analyzer_->const_int_bound(a);
if (const_int_bound_a->min_value >= 0) {
return std::nullopt;
}
const int64_t max_value_of_dtype =
Downcast<IntImm>(tvm::max_value(a->dtype.element_of()))->value;

// NOTE: ensures that (b-1) - a_min does not overflow
// also note: max_value_of_dtype + const_int_bound_a->min_value won't overflow
// since a_min is negative, adding it to a positive value will not overflow
if (b_value - 1 > max_value_of_dtype + const_int_bound_a->min_value) {
return std::nullopt;
}
int64_t c_value = ((b_value - 1) - const_int_bound_a->min_value) / b_value;
ICHECK_GT(c_value, 0);
// NOTE: the c_value * b_value risks in overflow
if (c_value > max_value_of_dtype / b_value) return std::nullopt;
// need to check if the offset numerator will overflow
// to ensure if don't overflow, we need to use max_value_of_dtype - b_value * c_value
// note that b_value * c_value is positive, max_value_of_dtype is also positive, so the
// subtraction will not overflow
if (const_int_bound_a->max_value > max_value_of_dtype - b_value * c_value) {
// a + b * c risks overflow
return std::nullopt;
}
return c_value;
}

// attribute maps, shared only when FLegalize == FLowerIntrinsic
std::vector<OpAttrMap<FLowerGeneral>> attr_maps_;
FLowerGeneral fma_{nullptr};
Expand Down
89 changes: 67 additions & 22 deletions tests/python/tir-transform/test_tir_transform_lower_intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,24 +35,35 @@ def lower_intrin(params, stmt):
return stmt.value if lower_expr else stmt.body


def check_value(expr, vx, vy, data, fref):
def check_value(expr, variables, data, fref):
"""
Check that expr evaluates to fref(*row) for each row in data.
variables: list of TIR vars [x] or [x, y] bound to the columns of data.
data: list of tuples, each tuple has len(variables) elements.
"""
n = len(data)
A = te.placeholder((n,), name="A", dtype=expr.dtype)
B = te.placeholder((n,), name="B", dtype=expr.dtype)
num_vars = len(variables)
assert num_vars >= 1 and all(len(row) == num_vars for row in data)

placeholders = [
te.placeholder((n,), name=f"v{i}", dtype=variables[i].dtype) for i in range(num_vars)
]

def make_binds(i):
x = expr
x = tvm.tir.Let(vx, A[i], x)
x = tvm.tir.Let(vy, B[i], x)
for j in range(num_vars - 1, -1, -1):
x = tvm.tir.Let(variables[j], placeholders[j][i], x)
return x

C = te.compute((n,), make_binds)
f = tvm.compile(te.create_prim_func([A, B, C]), "llvm")
a = tvm.runtime.tensor(np.array([x for x, y in data], dtype=expr.dtype))
b = tvm.runtime.tensor(np.array([y for x, y in data], dtype=expr.dtype))
c = tvm.runtime.tensor(np.zeros(len(data), dtype=expr.dtype))
f(a, b, c)
cref = np.array([fref(x, y) for x, y in data])
f = tvm.compile(te.create_prim_func(placeholders + [C]), "llvm")
arrays = [
tvm.runtime.tensor(np.array([row[j] for row in data], dtype=variables[j].dtype))
for j in range(num_vars)
]
c = tvm.runtime.tensor(np.zeros(n, dtype=expr.dtype))
f(*arrays, c)
cref = np.array([fref(*row) for row in data])
np.testing.assert_equal(c.numpy(), cref)


Expand All @@ -75,29 +86,29 @@ def test_lower_floordiv():
zero = tvm.tir.const(0, dtype)
# no constraints
res = lower_intrin([x, y], tvm.te.floordiv(x, y))
check_value(res, x, y, data, lambda a, b: a // b)
check_value(res, [x, y], data, lambda a, b: a // b)
# rhs >= 0
res = lower_intrin([x, y], tvm.tir.Select(y >= 0, tvm.te.floordiv(x, y), zero))
check_value(res, x, y, data, lambda a, b: a // b if b > 0 else 0)
check_value(res, [x, y], data, lambda a, b: a // b if b > 0 else 0)
# involves max
res = lower_intrin(
[x, y], tvm.tir.Select(y >= 0, tvm.te.max(tvm.te.floordiv(x, y), zero), zero)
)
check_value(res, x, y, data, lambda a, b: max(a // b, 0) if b > 0 else 0)
check_value(res, [x, y], data, lambda a, b: max(a // b, 0) if b > 0 else 0)
# lhs >= 0
res = lower_intrin(
[x, y], tvm.tir.Select(tvm.tir.all(y >= 0, x >= 0), tvm.te.floordiv(x, y), zero)
)
check_value(res, x, y, data, lambda a, b: a // b if b > 0 and a >= 0 else 0)
check_value(res, [x, y], data, lambda a, b: a // b if b > 0 and a >= 0 else 0)
# const power of two
res = lower_intrin([x, y], tvm.te.floordiv(x, tvm.tir.const(8, dtype=dtype)))
check_value(res, x, y, [(a, b) for a, b in data if b == 8], lambda a, b: a // b)
check_value(res, [x, y], [(a, b) for a, b in data if b == 8], lambda a, b: a // b)
# floordiv(x + m, k), m and k are positive constant. 2 <= m <= k-1.
res = lower_intrin(
[x, y],
tvm.te.floordiv(x + tvm.tir.const(4, dtype=dtype), tvm.tir.const(5, dtype=dtype)),
)
check_value(res, x, y, [(a, b) for a, b in data if b == 5], lambda a, b: (a + 4) // b)
check_value(res, [x, y], [(a, b) for a, b in data if b == 5], lambda a, b: (a + 4) // b)


@tvm.testing.requires_llvm
Expand All @@ -109,26 +120,60 @@ def test_lower_floormod():
zero = tvm.tir.const(0, dtype)
# no constraints
res = lower_intrin([x, y], tvm.te.floormod(x, y))
check_value(res, x, y, data, lambda a, b: a % b)
check_value(res, [x, y], data, lambda a, b: a % b)
# rhs >= 0
res = lower_intrin([x, y], tvm.tir.Select(y >= 0, tvm.te.floormod(x, y), zero))
check_value(res, x, y, data, lambda a, b: a % b if b > 0 else 0)
check_value(res, [x, y], data, lambda a, b: a % b if b > 0 else 0)
# lhs >= 0
res = lower_intrin(
[x, y], tvm.tir.Select(tvm.tir.all(y >= 0, x >= 0), tvm.te.floormod(x, y), zero)
)
check_value(res, x, y, data, lambda a, b: a % b if b > 0 and a >= 0 else 0)
check_value(res, [x, y], data, lambda a, b: a % b if b > 0 and a >= 0 else 0)
# const power of two
res = lower_intrin([x, y], tvm.te.floormod(x, tvm.tir.const(8, dtype=dtype)))
check_value(res, x, y, [(a, b) for a, b in data if b == 8], lambda a, b: a % b)
check_value(res, [x, y], [(a, b) for a, b in data if b == 8], lambda a, b: a % b)
# floormod(x + m, k), m and k are positive constant. 2 <= m <= k-1.
res = lower_intrin(
[x, y],
tvm.te.floormod(x + tvm.tir.const(4, dtype=dtype), tvm.tir.const(5, dtype=dtype)),
)
check_value(res, x, y, [(a, b) for a, b in data if b == 5], lambda a, b: (a + 4) % b)
check_value(res, [x, y], [(a, b) for a, b in data if b == 5], lambda a, b: (a + 4) % b)


@tvm.testing.requires_llvm
def test_lower_floordiv_overflow_checks():
"""
Regression tests for overflow checks in TryFindShiftCoefficientForPositiveRange.
Divisor is constant 3 (not 1 to avoid CSE, not power-of-two so we don't take the shift path).
Reuses lower_intrin and check_value; overflow tests use one var [x].
"""
# Check 3: (b-1) - a_min must not overflow (numerator and C++ int64).
# x (int64) full range -> min_value = -2^63. With b = 3: numerator = 2 - (-2^63) > LLONG_MAX.
x = te.var("x", dtype="int64")
res = lower_intrin([x], tvm.te.floordiv(x, tvm.tir.const(3, "int64")))
data_check3 = [(-(2**63),), (0,), (100,)]
check_value(res, [x], data_check3, lambda a: a // 3)

# Check 4: c_value * b_value must not overflow dtype.
# x (int16) full range -> min_value = -32768, c = ceil(32770/3) = 10923; 10923*3 > 32767.
x = te.var("x", dtype="int16")
res = lower_intrin([x], tvm.te.floordiv(x, tvm.tir.const(3, "int16")))
data_check4 = [(-32768,), (0,), (100,)]
check_value(res, [x], data_check4, lambda a: a // 3)

# Check 5: a_max + b*c must not overflow (offset numerator).
# tir.min(tir.max(x, -10), 32758) can give bounds [-10, 32758]; b=3, c=4; a_max + 12 > 32767.
# In practice this path may not be triggered. This test still validates correct lowering.
x = te.var("x", dtype="int16")
clamped = tvm.tir.min(
tvm.tir.max(x, tvm.tir.const(-10, "int16")), tvm.tir.const(32758, "int16")
)
res = lower_intrin([x], tvm.te.floordiv(clamped, tvm.tir.const(3, "int16")))
data_check5 = [(-10,), (0,), (32758,), (32757,)]
check_value(res, [x], data_check5, lambda a: (min(max(a, -10), 32758)) // 3)


if __name__ == "__main__":
test_lower_floordiv()
test_lower_floormod()
test_lower_floordiv_overflow_checks()
Loading